Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementing Telemetry V4 #329

Merged
merged 1 commit into from
Mar 29, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 86 additions & 80 deletions msal/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import logging
import sys
import warnings
import uuid
from threading import Lock

import requests

Expand All @@ -18,6 +18,7 @@
from .wstrust_request import send_request as wst_send_request
from .wstrust_response import *
from .token_cache import TokenCache
import msal.telemetry


# The __init__.py will import this. Not the other way around.
Expand Down Expand Up @@ -52,18 +53,6 @@ def decorate_scope(
decorated = scope_set | reserved_scope
return list(decorated)

CLIENT_REQUEST_ID = 'client-request-id'
CLIENT_CURRENT_TELEMETRY = 'x-client-current-telemetry'

def _get_new_correlation_id():
correlation_id = str(uuid.uuid4())
logger.debug("Generates correlation_id: %s", correlation_id)
return correlation_id


def _build_current_telemetry_request_header(public_api_id, force_refresh=False):
return "1|{},{}|".format(public_api_id, "1" if force_refresh else "0")


def extract_certs(public_cert_content):
# Parses raw public certificate file contents and returns a list of strings
Expand Down Expand Up @@ -257,6 +246,14 @@ def __init__(
self.token_cache = token_cache or TokenCache()
self.client = self._build_client(client_credential, self.authority)
self.authority_groups = None
self._telemetry_buffer = {}
self._telemetry_lock = Lock()

def _build_telemetry_context(
self, api_id, correlation_id=None, refresh_reason=None):
return msal.telemetry._TelemetryContext(
self._telemetry_buffer, self._telemetry_lock, api_id,
correlation_id=correlation_id, refresh_reason=refresh_reason)

def _build_client(self, client_credential, authority):
client_assertion = None
Expand Down Expand Up @@ -520,21 +517,21 @@ def authorize(): # A controller in a web app
return redirect(url_for("index"))
"""
self._validate_ssh_cert_input_data(kwargs.get("data", {}))
return _clean_up(self.client.obtain_token_by_auth_code_flow(
telemetry_context = self._build_telemetry_context(
self.ACQUIRE_TOKEN_BY_AUTHORIZATION_CODE_ID)
response =_clean_up(self.client.obtain_token_by_auth_code_flow(
auth_code_flow,
auth_response,
scope=decorate_scope(scopes, self.client_id) if scopes else None,
headers={
CLIENT_REQUEST_ID: _get_new_correlation_id(),
CLIENT_CURRENT_TELEMETRY: _build_current_telemetry_request_header(
self.ACQUIRE_TOKEN_BY_AUTHORIZATION_CODE_ID),
},
headers=telemetry_context.generate_headers(),
data=dict(
kwargs.pop("data", {}),
claims=_merge_claims_challenge_and_capabilities(
self._client_capabilities,
auth_code_flow.pop("claims_challenge", None))),
**kwargs))
telemetry_context.update_telemetry(response)
return response

def acquire_token_by_authorization_code(
self,
Expand Down Expand Up @@ -593,20 +590,20 @@ def acquire_token_by_authorization_code(
"Change your acquire_token_by_authorization_code() "
"to acquire_token_by_auth_code_flow()", DeprecationWarning)
with warnings.catch_warnings(record=True):
return _clean_up(self.client.obtain_token_by_authorization_code(
telemetry_context = self._build_telemetry_context(
self.ACQUIRE_TOKEN_BY_AUTHORIZATION_CODE_ID)
response = _clean_up(self.client.obtain_token_by_authorization_code(
code, redirect_uri=redirect_uri,
scope=decorate_scope(scopes, self.client_id),
headers={
CLIENT_REQUEST_ID: _get_new_correlation_id(),
CLIENT_CURRENT_TELEMETRY: _build_current_telemetry_request_header(
self.ACQUIRE_TOKEN_BY_AUTHORIZATION_CODE_ID),
},
headers=telemetry_context.generate_headers(),
data=dict(
kwargs.pop("data", {}),
claims=_merge_claims_challenge_and_capabilities(
self._client_capabilities, claims_challenge)),
nonce=nonce,
**kwargs))
telemetry_context.update_telemetry(resposne)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A typo sneaked in. It will be fixed in PR #345

return response

def get_accounts(self, username=None):
"""Get a list of accounts which previously signed in, i.e. exists in cache.
Expand Down Expand Up @@ -735,7 +732,7 @@ def acquire_token_silent(
- None when cache lookup does not yield a token.
"""
result = self.acquire_token_silent_with_error(
scopes, account, authority, force_refresh,
scopes, account, authority=authority, force_refresh=force_refresh,
claims_challenge=claims_challenge, **kwargs)
return result if result and "error" not in result else None

Expand Down Expand Up @@ -780,7 +777,7 @@ def acquire_token_silent_with_error(
"""
assert isinstance(scopes, list), "Invalid parameter type"
self._validate_ssh_cert_input_data(kwargs.get("data", {}))
correlation_id = _get_new_correlation_id()
correlation_id = msal.telemetry._get_new_correlation_id()
if authority:
warnings.warn("We haven't decided how/if this method will accept authority parameter")
# the_authority = Authority(
Expand Down Expand Up @@ -851,9 +848,11 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
target=scopes,
query=query)
now = time.time()
refresh_reason = msal.telemetry.AT_ABSENT
for entry in matches:
expires_in = int(entry["expires_on"]) - now
if expires_in < 5*60: # Then consider it expired
refresh_reason = msal.telemetry.AT_EXPIRED
continue # Removal is not necessary, it will be overwritten
logger.debug("Cache hit an AT")
access_token_from_cache = { # Mimic a real response
Expand All @@ -862,13 +861,18 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
"expires_in": int(expires_in), # OAuth2 specs defines it as int
}
if "refresh_on" in entry and int(entry["refresh_on"]) < now: # aging
refresh_reason = msal.telemetry.AT_AGING
break # With a fallback in hand, we break here to go refresh
self._build_telemetry_context(-1).hit_an_access_token()
return access_token_from_cache # It is still good as new
else:
refresh_reason = msal.telemetry.FORCE_REFRESH # TODO: It could also mean claims_challenge
assert refresh_reason, "It should have been established at this point"
try:
result = self._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family(
result = _clean_up(self._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family(
authority, decorate_scope(scopes, self.client_id), account,
force_refresh=force_refresh, claims_challenge=claims_challenge, **kwargs)
result = _clean_up(result)
refresh_reason=refresh_reason, claims_challenge=claims_challenge,
**kwargs))
if (result and "error" not in result) or (not access_token_from_cache):
return result
except: # The exact HTTP exception is transportation-layer dependent
Expand Down Expand Up @@ -922,7 +926,8 @@ def _get_app_metadata(self, environment):
def _acquire_token_silent_by_finding_specific_refresh_token(
self, authority, scopes, query,
rt_remover=None, break_condition=lambda response: False,
force_refresh=False, correlation_id=None, claims_challenge=None, **kwargs):
refresh_reason=None, correlation_id=None, claims_challenge=None,
**kwargs):
matches = self.token_cache.find(
self.token_cache.CredentialType.REFRESH_TOKEN,
# target=scopes, # AAD RTs are scope-independent
Expand All @@ -931,6 +936,9 @@ def _acquire_token_silent_by_finding_specific_refresh_token(
client = self._build_client(self.client_credential, authority)

response = None # A distinguishable value to mean cache is empty
telemetry_context = self._build_telemetry_context(
self.ACQUIRE_TOKEN_SILENT_ID,
correlation_id=correlation_id, refresh_reason=refresh_reason)
for entry in sorted( # Since unfit RTs would not be aggressively removed,
# we start from newer RTs which are more likely fit.
matches,
Expand All @@ -948,16 +956,13 @@ def _acquire_token_silent_by_finding_specific_refresh_token(
skip_account_creation=True, # To honor a concurrent remove_account()
)),
scope=scopes,
headers={
CLIENT_REQUEST_ID: correlation_id or _get_new_correlation_id(),
CLIENT_CURRENT_TELEMETRY: _build_current_telemetry_request_header(
self.ACQUIRE_TOKEN_SILENT_ID, force_refresh=force_refresh),
},
headers=telemetry_context.generate_headers(),
data=dict(
kwargs.pop("data", {}),
claims=_merge_claims_challenge_and_capabilities(
self._client_capabilities, claims_challenge)),
**kwargs)
telemetry_context.update_telemetry(response)
if "error" not in response:
return response
logger.debug("Refresh failed. {error}: {error_description}".format(
Expand Down Expand Up @@ -1006,18 +1011,19 @@ def acquire_token_by_refresh_token(self, refresh_token, scopes, **kwargs):
* A dict contains no "error" key means migration was successful.
"""
self._validate_ssh_cert_input_data(kwargs.get("data", {}))
return _clean_up(self.client.obtain_token_by_refresh_token(
telemetry_context = self._build_telemetry_context(
self.ACQUIRE_TOKEN_BY_REFRESH_TOKEN,
refresh_reason=msal.telemetry.FORCE_REFRESH)
response = _clean_up(self.client.obtain_token_by_refresh_token(
refresh_token,
scope=decorate_scope(scopes, self.client_id),
headers={
CLIENT_REQUEST_ID: _get_new_correlation_id(),
CLIENT_CURRENT_TELEMETRY: _build_current_telemetry_request_header(
self.ACQUIRE_TOKEN_BY_REFRESH_TOKEN),
},
headers=telemetry_context.generate_headers(),
rt_getter=lambda rt: rt,
on_updating_rt=False,
on_removing_rt=lambda rt_item: None, # No OP
**kwargs))
telemetry_context.update_telemetry(response)
return response


class PublicClientApplication(ClientApplication): # browser app or mobile app
Expand Down Expand Up @@ -1093,7 +1099,9 @@ def acquire_token_interactive(
self._validate_ssh_cert_input_data(kwargs.get("data", {}))
claims = _merge_claims_challenge_and_capabilities(
self._client_capabilities, claims_challenge)
return _clean_up(self.client.obtain_token_by_browser(
telemetry_context = self._build_telemetry_context(
self.ACQUIRE_TOKEN_INTERACTIVE)
response = _clean_up(self.client.obtain_token_by_browser(
scope=decorate_scope(scopes, self.client_id) if scopes else None,
extra_scope_to_consent=extra_scopes_to_consent,
redirect_uri="http://localhost:{port}".format(
Expand All @@ -1107,12 +1115,10 @@ def acquire_token_interactive(
"domain_hint": domain_hint,
},
data=dict(kwargs.pop("data", {}), claims=claims),
headers={
CLIENT_REQUEST_ID: _get_new_correlation_id(),
CLIENT_CURRENT_TELEMETRY: _build_current_telemetry_request_header(
self.ACQUIRE_TOKEN_INTERACTIVE),
},
headers=telemetry_context.generate_headers(),
**kwargs))
telemetry_context.update_telemetry(response)
return response

def initiate_device_flow(self, scopes=None, **kwargs):
"""Initiate a Device Flow instance,
Expand All @@ -1125,13 +1131,10 @@ def initiate_device_flow(self, scopes=None, **kwargs):
- A successful response would contain "user_code" key, among others
- an error response would contain some other readable key/value pairs.
"""
correlation_id = _get_new_correlation_id()
correlation_id = msal.telemetry._get_new_correlation_id()
flow = self.client.initiate_device_flow(
scope=decorate_scope(scopes or [], self.client_id),
headers={
CLIENT_REQUEST_ID: correlation_id,
# CLIENT_CURRENT_TELEMETRY is not currently required
},
headers={msal.telemetry.CLIENT_REQUEST_ID: correlation_id},
**kwargs)
flow[self.DEVICE_FLOW_CORRELATION_ID] = correlation_id
return flow
Expand All @@ -1155,7 +1158,10 @@ def acquire_token_by_device_flow(self, flow, claims_challenge=None, **kwargs):
- A successful response would contain "access_token" key,
- an error response would contain "error" and usually "error_description".
"""
return _clean_up(self.client.obtain_token_by_device_flow(
telemetry_context = self._build_telemetry_context(
self.ACQUIRE_TOKEN_BY_DEVICE_FLOW_ID,
correlation_id=flow.get(self.DEVICE_FLOW_CORRELATION_ID))
response = _clean_up(self.client.obtain_token_by_device_flow(
flow,
data=dict(
kwargs.pop("data", {}),
Expand All @@ -1165,13 +1171,10 @@ def acquire_token_by_device_flow(self, flow, claims_challenge=None, **kwargs):
claims=_merge_claims_challenge_and_capabilities(
self._client_capabilities, claims_challenge),
),
headers={
CLIENT_REQUEST_ID:
flow.get(self.DEVICE_FLOW_CORRELATION_ID) or _get_new_correlation_id(),
CLIENT_CURRENT_TELEMETRY: _build_current_telemetry_request_header(
self.ACQUIRE_TOKEN_BY_DEVICE_FLOW_ID),
},
headers=telemetry_context.generate_headers(),
**kwargs))
telemetry_context.update_telemetry(response)
return response

def acquire_token_by_username_password(
self, username, password, scopes, claims_challenge=None, **kwargs):
Expand All @@ -1196,28 +1199,30 @@ def acquire_token_by_username_password(
- an error response would contain "error" and usually "error_description".
"""
scopes = decorate_scope(scopes, self.client_id)
headers = {
CLIENT_REQUEST_ID: _get_new_correlation_id(),
CLIENT_CURRENT_TELEMETRY: _build_current_telemetry_request_header(
self.ACQUIRE_TOKEN_BY_USERNAME_PASSWORD_ID),
}
telemetry_context = self._build_telemetry_context(
self.ACQUIRE_TOKEN_BY_USERNAME_PASSWORD_ID)
headers = telemetry_context.generate_headers()
data = dict(
kwargs.pop("data", {}),
claims=_merge_claims_challenge_and_capabilities(
self._client_capabilities, claims_challenge))
if not self.authority.is_adfs:
user_realm_result = self.authority.user_realm_discovery(
username, correlation_id=headers[CLIENT_REQUEST_ID])
username, correlation_id=headers[msal.telemetry.CLIENT_REQUEST_ID])
if user_realm_result.get("account_type") == "Federated":
return _clean_up(self._acquire_token_by_username_password_federated(
response = _clean_up(self._acquire_token_by_username_password_federated(
user_realm_result, username, password, scopes=scopes,
data=data,
headers=headers, **kwargs))
return _clean_up(self.client.obtain_token_by_username_password(
telemetry_context.update_telemetry(response)
return response
response = _clean_up(self.client.obtain_token_by_username_password(
username, password, scope=scopes,
headers=headers,
data=data,
**kwargs))
telemetry_context.update_telemetry(response)
return response

def _acquire_token_by_username_password_federated(
self, user_realm_result, username, password, scopes=None, **kwargs):
Expand Down Expand Up @@ -1277,18 +1282,18 @@ def acquire_token_for_client(self, scopes, claims_challenge=None, **kwargs):
"""
# TBD: force_refresh behavior
self._validate_ssh_cert_input_data(kwargs.get("data", {}))
return _clean_up(self.client.obtain_token_for_client(
telemetry_context = self._build_telemetry_context(
self.ACQUIRE_TOKEN_FOR_CLIENT_ID)
response = _clean_up(self.client.obtain_token_for_client(
scope=scopes, # This grant flow requires no scope decoration
headers={
CLIENT_REQUEST_ID: _get_new_correlation_id(),
CLIENT_CURRENT_TELEMETRY: _build_current_telemetry_request_header(
self.ACQUIRE_TOKEN_FOR_CLIENT_ID),
},
headers=telemetry_context.generate_headers(),
data=dict(
kwargs.pop("data", {}),
claims=_merge_claims_challenge_and_capabilities(
self._client_capabilities, claims_challenge)),
**kwargs))
telemetry_context.update_telemetry(response)
return response

def acquire_token_on_behalf_of(self, user_assertion, scopes, claims_challenge=None, **kwargs):
"""Acquires token using on-behalf-of (OBO) flow.
Expand Down Expand Up @@ -1316,9 +1321,11 @@ def acquire_token_on_behalf_of(self, user_assertion, scopes, claims_challenge=No
- A successful response would contain "access_token" key,
- an error response would contain "error" and usually "error_description".
"""
telemetry_context = self._build_telemetry_context(
self.ACQUIRE_TOKEN_ON_BEHALF_OF_ID)
# The implementation is NOT based on Token Exchange
# https://tools.ietf.org/html/draft-ietf-oauth-token-exchange-16
return _clean_up(self.client.obtain_token_by_assertion( # bases on assertion RFC 7521
response = _clean_up(self.client.obtain_token_by_assertion( # bases on assertion RFC 7521
user_assertion,
self.client.GRANT_TYPE_JWT, # IDTs and AAD ATs are all JWTs
scope=decorate_scope(scopes, self.client_id), # Decoration is used for:
Expand All @@ -1332,9 +1339,8 @@ def acquire_token_on_behalf_of(self, user_assertion, scopes, claims_challenge=No
requested_token_use="on_behalf_of",
claims=_merge_claims_challenge_and_capabilities(
self._client_capabilities, claims_challenge)),
headers={
CLIENT_REQUEST_ID: _get_new_correlation_id(),
CLIENT_CURRENT_TELEMETRY: _build_current_telemetry_request_header(
self.ACQUIRE_TOKEN_ON_BEHALF_OF_ID),
},
headers=telemetry_context.generate_headers(),
**kwargs))
telemetry_context.update_telemetry(response)
return response

Loading