Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 51 additions & 5 deletions msal/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@


# The __init__.py will import this. Not the other way around.
__version__ = "0.3.1"
__version__ = "0.4.0"

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -280,6 +280,49 @@ def _get_authority_aliases(self, instance):
return [alias for alias in group if alias != instance]
return []

def remove_account(self, account):
"""Sign me out and forget me from token cache"""
self._forget_me(account)

def _sign_out(self, home_account):
# Remove all relevant RTs and ATs from token cache
owned_by_home_account = {
"environment": home_account["environment"],
"home_account_id": home_account["home_account_id"],} # realm-independent
app_metadata = self._get_app_metadata(home_account["environment"])
# Remove RTs/FRTs, and they are realm-independent
for rt in [rt for rt in self.token_cache.find(
TokenCache.CredentialType.REFRESH_TOKEN, query=owned_by_home_account)
# Do RT's app ownership check as a precaution, in case family apps
# and 3rd-party apps share same token cache, although they should not.
if rt["client_id"] == self.client_id or (
app_metadata.get("family_id") # Now let's settle family business
and rt.get("family_id") == app_metadata["family_id"])
]:
self.token_cache.remove_rt(rt)
for at in self.token_cache.find( # Remove ATs
# Regardless of realm, b/c we've removed realm-independent RTs anyway
TokenCache.CredentialType.ACCESS_TOKEN, query=owned_by_home_account):
# To avoid the complexity of locating sibling family app's AT,
# we skip AT's app ownership check.
# It means ATs for other apps will also be removed, it is OK because:
# * non-family apps are not supposed to share token cache to begin with;
# * Even if it happens, we keep other app's RT already, so SSO still works
self.token_cache.remove_at(at)

def _forget_me(self, home_account):
# It implies signout, and then also remove all relevant accounts and IDTs
self._sign_out(home_account)
owned_by_home_account = {
"environment": home_account["environment"],
"home_account_id": home_account["home_account_id"],} # realm-independent
for idt in self.token_cache.find( # Remove IDTs, regardless of realm
TokenCache.CredentialType.ID_TOKEN, query=owned_by_home_account):
self.token_cache.remove_idt(idt)
for a in self.token_cache.find( # Remove Accounts, regardless of realm
TokenCache.CredentialType.ACCOUNT, query=owned_by_home_account):
self.token_cache.remove_account(a)

def acquire_token_silent(
self,
scopes, # type: List[str]
Expand Down Expand Up @@ -364,10 +407,7 @@ def _acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family(
"home_account_id": (account or {}).get("home_account_id"),
# "realm": authority.tenant, # AAD RTs are tenant-independent
}
apps = self.token_cache.find( # Use find(), rather than token_cache.get(...)
TokenCache.CredentialType.APP_METADATA, query={
"environment": authority.instance, "client_id": self.client_id})
app_metadata = apps[0] if apps else {}
app_metadata = self._get_app_metadata(authority.instance)
if not app_metadata: # Meaning this app is now used for the first time.
# When/if we have a way to directly detect current app's family,
# we'll rewrite this block, to support multiple families.
Expand Down Expand Up @@ -396,6 +436,12 @@ def _acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family(
return self._acquire_token_silent_by_finding_specific_refresh_token(
authority, scopes, dict(query, client_id=self.client_id), **kwargs)

def _get_app_metadata(self, environment):
apps = self.token_cache.find( # Use find(), rather than token_cache.get(...)
TokenCache.CredentialType.APP_METADATA, query={
"environment": environment, "client_id": self.client_id})
return apps[0] if apps else {}

def _acquire_token_silent_by_finding_specific_refresh_token(
self, authority, scopes, query,
rt_remover=None, break_condition=lambda response: False, **kwargs):
Expand Down
125 changes: 88 additions & 37 deletions msal/token_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,12 @@ class AuthorityType:
def __init__(self):
self._lock = threading.RLock()
self._cache = {}
self.key_makers = {
self.CredentialType.REFRESH_TOKEN: self._build_rt_key,
self.CredentialType.ACCESS_TOKEN: self._build_at_key,
self.CredentialType.ID_TOKEN: self._build_idt_key,
self.CredentialType.ACCOUNT: self._build_account_key,
}

def find(self, credential_type, target=None, query=None):
target = target or []
Expand Down Expand Up @@ -83,14 +89,9 @@ def add(self, event, now=None):
with self._lock:

if access_token:
key = "-".join([
home_account_id or "",
environment or "",
self.CredentialType.ACCESS_TOKEN,
event.get("client_id", ""),
realm or "",
target,
]).lower()
key = self._build_at_key(
home_account_id, environment, event.get("client_id", ""),
realm, target)
now = time.time() if now is None else now
expires_in = response.get("expires_in", 3599)
self._cache.setdefault(self.CredentialType.ACCESS_TOKEN, {})[key] = {
Expand All @@ -110,11 +111,7 @@ def add(self, event, now=None):
if client_info:
decoded_id_token = json.loads(
base64decode(id_token.split('.')[1])) if id_token else {}
key = "-".join([
home_account_id or "",
environment or "",
realm or "",
]).lower()
key = self._build_account_key(home_account_id, environment, realm)
self._cache.setdefault(self.CredentialType.ACCOUNT, {})[key] = {
"home_account_id": home_account_id,
"environment": environment,
Expand All @@ -129,14 +126,8 @@ def add(self, event, now=None):
}

if id_token:
key = "-".join([
home_account_id or "",
environment or "",
self.CredentialType.ID_TOKEN,
event.get("client_id", ""),
realm or "",
"" # Albeit irrelevant, schema requires an empty scope here
]).lower()
key = self._build_idt_key(
home_account_id, environment, event.get("client_id", ""), realm)
self._cache.setdefault(self.CredentialType.ID_TOKEN, {})[key] = {
"credential_type": self.CredentialType.ID_TOKEN,
"secret": id_token,
Expand Down Expand Up @@ -170,6 +161,24 @@ def add(self, event, now=None):
"family_id": response.get("foci"), # None is also valid
}

def modify(self, credential_type, old_entry, new_key_value_pairs=None):
# Modify the specified old_entry with new_key_value_pairs,
# or remove the old_entry if the new_key_value_pairs is None.

# This helper exists to consolidate all token modify/remove behaviors,
# so that the sub-classes will have only one method to work on,
# instead of patching a pair of update_xx() and remove_xx() per type.
# You can monkeypatch self.key_makers to support more types on-the-fly.
key = self.key_makers[credential_type](**old_entry)
with self._lock:
if new_key_value_pairs: # Update with them
entries = self._cache.setdefault(credential_type, {})
entry = entries.get(key, {}) # key usually exists, but we'll survive its absence
entry.update(new_key_value_pairs)
else: # Remove old_entry
self._cache.setdefault(credential_type, {}).pop(key, None)


@staticmethod
def _build_appmetadata_key(environment, client_id):
return "appmetadata-{}-{}".format(environment or "", client_id or "")
Expand All @@ -178,7 +187,7 @@ def _build_appmetadata_key(environment, client_id):
def _build_rt_key(
cls,
home_account_id=None, environment=None, client_id=None, target=None,
**ignored):
**ignored_payload_from_a_real_token):
return "-".join([
home_account_id or "",
environment or "",
Expand All @@ -189,16 +198,61 @@ def _build_rt_key(
]).lower()

def remove_rt(self, rt_item):
key = self._build_rt_key(**rt_item)
with self._lock:
self._cache.setdefault(self.CredentialType.REFRESH_TOKEN, {}).pop(key, None)
assert rt_item.get("credential_type") == self.CredentialType.REFRESH_TOKEN
return self.modify(self.CredentialType.REFRESH_TOKEN, rt_item)

def update_rt(self, rt_item, new_rt):
key = self._build_rt_key(**rt_item)
with self._lock:
RTs = self._cache.setdefault(self.CredentialType.REFRESH_TOKEN, {})
rt = RTs.get(key, {}) # key usually exists, but we'll survive its absence
rt["secret"] = new_rt
assert rt_item.get("credential_type") == self.CredentialType.REFRESH_TOKEN
return self.modify(
self.CredentialType.REFRESH_TOKEN, rt_item, {"secret": new_rt})

@classmethod
def _build_at_key(cls,
home_account_id=None, environment=None, client_id=None,
realm=None, target=None, **ignored_payload_from_a_real_token):
return "-".join([
home_account_id or "",
environment or "",
cls.CredentialType.ACCESS_TOKEN,
client_id,
realm or "",
target or "",
]).lower()

def remove_at(self, at_item):
assert at_item.get("credential_type") == self.CredentialType.ACCESS_TOKEN
return self.modify(self.CredentialType.ACCESS_TOKEN, at_item)

@classmethod
def _build_idt_key(cls,
home_account_id=None, environment=None, client_id=None, realm=None,
**ignored_payload_from_a_real_token):
return "-".join([
home_account_id or "",
environment or "",
cls.CredentialType.ID_TOKEN,
client_id or "",
realm or "",
"" # Albeit irrelevant, schema requires an empty scope here
]).lower()

def remove_idt(self, idt_item):
assert idt_item.get("credential_type") == self.CredentialType.ID_TOKEN
return self.modify(self.CredentialType.ID_TOKEN, idt_item)

@classmethod
def _build_account_key(cls,
home_account_id=None, environment=None, realm=None,
**ignored_payload_from_a_real_entry):
return "-".join([
home_account_id or "",
environment or "",
realm or "",
]).lower()

def remove_account(self, account_item):
assert "authority_type" in account_item
return self.modify(self.CredentialType.ACCOUNT, account_item)


class SerializableTokenCache(TokenCache):
Expand All @@ -221,7 +275,7 @@ class SerializableTokenCache(TokenCache):
...

:var bool has_state_changed:
Indicates whether the cache state has changed since last
Indicates whether the cache state in the memory has changed since last
:func:`~serialize` or :func:`~deserialize` call.
"""
has_state_changed = False
Expand All @@ -230,12 +284,9 @@ def add(self, event, **kwargs):
super(SerializableTokenCache, self).add(event, **kwargs)
self.has_state_changed = True

def remove_rt(self, rt_item):
super(SerializableTokenCache, self).remove_rt(rt_item)
self.has_state_changed = True

def update_rt(self, rt_item, new_rt):
super(SerializableTokenCache, self).update_rt(rt_item, new_rt)
def modify(self, credential_type, old_entry, new_key_value_pairs=None):
super(SerializableTokenCache, self).modify(
credential_type, old_entry, new_key_value_pairs)
self.has_state_changed = True

def deserialize(self, state):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
"""
The configuration file would look like this:

{
"authority": "https://login.microsoftonline.com/organizations",
"client_id": "your_client_id",
"scope": ["https://graph.microsoft.com/.default"],
"redirect_uri": "http://localhost:5000/getAToken",
// Configure this redirect uri for this sample
// redirect_uri should match what you've configured in here
// https://docs.microsoft.com/en-us/azure/active-directory/develop/quickstart-configure-app-access-web-apis#add-redirect-uris-to-your-application
"client_secret": "yoursecret"
}

You can then run this sample with a JSON configuration file:
python sample.py parameters.json
On the browser open http://localhost:5000/

"""

import sys # For simplicity, we'll read config file from 1st CLI param sys.argv[1]
import json
import logging
import uuid

import flask

import msal

app = flask.Flask(__name__)
app.debug = True
app.secret_key = 'development'


# Optional logging
# logging.basicConfig(level=logging.DEBUG)

config = json.load(open(sys.argv[1]))

application = msal.ConfidentialClientApplication(
config["client_id"], authority=config["authority"],
client_credential=config["client_secret"],
# token_cache=... # Default cache is in memory only.
# You can learn how to use SerializableTokenCache from
# https://msal-python.rtfd.io/en/latest/#msal.SerializableTokenCache
)


@app.route("/")
def main():
resp = flask.Response(status=307)
resp.headers['location'] = '/login'
return resp


@app.route("/login")
def login():
auth_state = str(uuid.uuid4())
flask.session['state'] = auth_state
authorization_url = application.get_authorization_request_url(config['scope'], state=auth_state,
redirect_uri=config['redirect_uri'])
resp = flask.Response(status=307)
resp.headers['location'] = authorization_url
return resp


@app.route("/getAToken")
def main_logic():
code = flask.request.args['code']
state = flask.request.args['state']
if state != flask.session['state']:
raise ValueError("State does not match")

result = application.acquire_token_by_authorization_code(code, scopes=config["scope"],
redirect_uri=config['redirect_uri'])
return flask.render_template('display.html', auth_result=result)


if __name__ == "__main__":
app.run()
19 changes: 19 additions & 0 deletions sample/authorization-code-flow-sample/templates/display.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<title>Acquire Token Result </title>
</head>
<body>
<p1><b>Acquire Token Result</b> </p1>
<table>
{% for key, value in auth_result.items() %}
<tr>
<th> {{ key }} </th>
<td> {{ value }} </td>
</tr>
{% endfor %}
</table>

</body>
</html>
Loading