Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 50 additions & 4 deletions msal/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,49 @@ def _get_authority_aliases(self, instance):
return [alias for alias in group if alias != instance]
return []

def remove_account(self, account):
"""Sign me out and forget me from token cache"""
self._forget_me(account)

def _sign_out(self, home_account):
# Remove all relevant RTs and ATs from token cache
owned_by_home_account = {
"environment": home_account["environment"],
"home_account_id": home_account["home_account_id"],} # realm-independent
app_metadata = self._get_app_metadata(home_account["environment"])
# Remove RTs/FRTs, and they are realm-independent
for rt in [rt for rt in self.token_cache.find(
TokenCache.CredentialType.REFRESH_TOKEN, query=owned_by_home_account)
# Do RT's app ownership check as a precaution, in case family apps
# and 3rd-party apps share same token cache, although they should not.
if rt["client_id"] == self.client_id or (
app_metadata.get("family_id") # Now let's settle family business
and rt.get("family_id") == app_metadata["family_id"])
]:
self.token_cache.remove_rt(rt)
for at in self.token_cache.find( # Remove ATs
# Regardless of realm, b/c we've removed realm-independent RTs anyway
TokenCache.CredentialType.ACCESS_TOKEN, query=owned_by_home_account):
# To avoid the complexity of locating sibling family app's AT,
# we skip AT's app ownership check.
# It means ATs for other apps will also be removed, it is OK because:
# * non-family apps are not supposed to share token cache to begin with;
# * Even if it happens, we keep other app's RT already, so SSO still works
self.token_cache.remove_at(at)

def _forget_me(self, home_account):
# It implies signout, and then also remove all relevant accounts and IDTs
self._sign_out(home_account)
owned_by_home_account = {
"environment": home_account["environment"],
"home_account_id": home_account["home_account_id"],} # realm-independent
for idt in self.token_cache.find( # Remove IDTs, regardless of realm
TokenCache.CredentialType.ID_TOKEN, query=owned_by_home_account):
self.token_cache.remove_idt(idt)
for a in self.token_cache.find( # Remove Accounts, regardless of realm
TokenCache.CredentialType.ACCOUNT, query=owned_by_home_account):
self.token_cache.remove_account(a)

def acquire_token_silent(
self,
scopes, # type: List[str]
Expand Down Expand Up @@ -364,10 +407,7 @@ def _acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family(
"home_account_id": (account or {}).get("home_account_id"),
# "realm": authority.tenant, # AAD RTs are tenant-independent
}
apps = self.token_cache.find( # Use find(), rather than token_cache.get(...)
TokenCache.CredentialType.APP_METADATA, query={
"environment": authority.instance, "client_id": self.client_id})
app_metadata = apps[0] if apps else {}
app_metadata = self._get_app_metadata(authority.instance)
if not app_metadata: # Meaning this app is now used for the first time.
# When/if we have a way to directly detect current app's family,
# we'll rewrite this block, to support multiple families.
Expand Down Expand Up @@ -396,6 +436,12 @@ def _acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family(
return self._acquire_token_silent_by_finding_specific_refresh_token(
authority, scopes, dict(query, client_id=self.client_id), **kwargs)

def _get_app_metadata(self, environment):
apps = self.token_cache.find( # Use find(), rather than token_cache.get(...)
TokenCache.CredentialType.APP_METADATA, query={
"environment": environment, "client_id": self.client_id})
return apps[0] if apps else {}

def _acquire_token_silent_by_finding_specific_refresh_token(
self, authority, scopes, query,
rt_remover=None, break_condition=lambda response: False, **kwargs):
Expand Down
125 changes: 88 additions & 37 deletions msal/token_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,12 @@ class AuthorityType:
def __init__(self):
self._lock = threading.RLock()
self._cache = {}
self.key_makers = {
self.CredentialType.REFRESH_TOKEN: self._build_rt_key,
self.CredentialType.ACCESS_TOKEN: self._build_at_key,
self.CredentialType.ID_TOKEN: self._build_idt_key,
self.CredentialType.ACCOUNT: self._build_account_key,
}

def find(self, credential_type, target=None, query=None):
target = target or []
Expand Down Expand Up @@ -83,14 +89,9 @@ def add(self, event, now=None):
with self._lock:

if access_token:
key = "-".join([
home_account_id or "",
environment or "",
self.CredentialType.ACCESS_TOKEN,
event.get("client_id", ""),
realm or "",
target,
]).lower()
key = self._build_at_key(
home_account_id, environment, event.get("client_id", ""),
realm, target)
now = time.time() if now is None else now
expires_in = response.get("expires_in", 3599)
self._cache.setdefault(self.CredentialType.ACCESS_TOKEN, {})[key] = {
Expand All @@ -110,11 +111,7 @@ def add(self, event, now=None):
if client_info:
decoded_id_token = json.loads(
base64decode(id_token.split('.')[1])) if id_token else {}
key = "-".join([
home_account_id or "",
environment or "",
realm or "",
]).lower()
key = self._build_account_key(home_account_id, environment, realm)
self._cache.setdefault(self.CredentialType.ACCOUNT, {})[key] = {
"home_account_id": home_account_id,
"environment": environment,
Expand All @@ -129,14 +126,8 @@ def add(self, event, now=None):
}

if id_token:
key = "-".join([
home_account_id or "",
environment or "",
self.CredentialType.ID_TOKEN,
event.get("client_id", ""),
realm or "",
"" # Albeit irrelevant, schema requires an empty scope here
]).lower()
key = self._build_idt_key(
home_account_id, environment, event.get("client_id", ""), realm)
self._cache.setdefault(self.CredentialType.ID_TOKEN, {})[key] = {
"credential_type": self.CredentialType.ID_TOKEN,
"secret": id_token,
Expand Down Expand Up @@ -170,6 +161,24 @@ def add(self, event, now=None):
"family_id": response.get("foci"), # None is also valid
}

def modify(self, credential_type, old_entry, new_key_value_pairs=None):
# Modify the specified old_entry with new_key_value_pairs,
# or remove the old_entry if the new_key_value_pairs is None.

# This helper exists to consolidate all token modify/remove behaviors,
# so that the sub-classes will have only one method to work on,
# instead of patching a pair of update_xx() and remove_xx() per type.
# You can monkeypatch self.key_makers to support more types on-the-fly.
key = self.key_makers[credential_type](**old_entry)
with self._lock:
if new_key_value_pairs: # Update with them
entries = self._cache.setdefault(credential_type, {})
entry = entries.get(key, {}) # key usually exists, but we'll survive its absence
entry.update(new_key_value_pairs)
else: # Remove old_entry
self._cache.setdefault(credential_type, {}).pop(key, None)


@staticmethod
def _build_appmetadata_key(environment, client_id):
return "appmetadata-{}-{}".format(environment or "", client_id or "")
Expand All @@ -178,7 +187,7 @@ def _build_appmetadata_key(environment, client_id):
def _build_rt_key(
cls,
home_account_id=None, environment=None, client_id=None, target=None,
**ignored):
**ignored_payload_from_a_real_token):
return "-".join([
home_account_id or "",
environment or "",
Expand All @@ -189,16 +198,61 @@ def _build_rt_key(
]).lower()

def remove_rt(self, rt_item):
key = self._build_rt_key(**rt_item)
with self._lock:
self._cache.setdefault(self.CredentialType.REFRESH_TOKEN, {}).pop(key, None)
assert rt_item.get("credential_type") == self.CredentialType.REFRESH_TOKEN
return self.modify(self.CredentialType.REFRESH_TOKEN, rt_item)

def update_rt(self, rt_item, new_rt):
key = self._build_rt_key(**rt_item)
with self._lock:
RTs = self._cache.setdefault(self.CredentialType.REFRESH_TOKEN, {})
rt = RTs.get(key, {}) # key usually exists, but we'll survive its absence
rt["secret"] = new_rt
assert rt_item.get("credential_type") == self.CredentialType.REFRESH_TOKEN
return self.modify(
self.CredentialType.REFRESH_TOKEN, rt_item, {"secret": new_rt})

@classmethod
def _build_at_key(cls,
home_account_id=None, environment=None, client_id=None,
realm=None, target=None, **ignored_payload_from_a_real_token):
return "-".join([
home_account_id or "",
environment or "",
cls.CredentialType.ACCESS_TOKEN,
client_id,
realm or "",
target or "",
]).lower()

def remove_at(self, at_item):
assert at_item.get("credential_type") == self.CredentialType.ACCESS_TOKEN
return self.modify(self.CredentialType.ACCESS_TOKEN, at_item)

@classmethod
def _build_idt_key(cls,
home_account_id=None, environment=None, client_id=None, realm=None,
**ignored_payload_from_a_real_token):
return "-".join([
home_account_id or "",
environment or "",
cls.CredentialType.ID_TOKEN,
client_id or "",
realm or "",
"" # Albeit irrelevant, schema requires an empty scope here
]).lower()

def remove_idt(self, idt_item):
assert idt_item.get("credential_type") == self.CredentialType.ID_TOKEN
return self.modify(self.CredentialType.ID_TOKEN, idt_item)

@classmethod
def _build_account_key(cls,
home_account_id=None, environment=None, realm=None,
**ignored_payload_from_a_real_entry):
return "-".join([
home_account_id or "",
environment or "",
realm or "",
]).lower()

def remove_account(self, account_item):
assert "authority_type" in account_item
return self.modify(self.CredentialType.ACCOUNT, account_item)


class SerializableTokenCache(TokenCache):
Expand All @@ -221,7 +275,7 @@ class SerializableTokenCache(TokenCache):
...

:var bool has_state_changed:
Indicates whether the cache state has changed since last
Indicates whether the cache state in the memory has changed since last
:func:`~serialize` or :func:`~deserialize` call.
"""
has_state_changed = False
Expand All @@ -230,12 +284,9 @@ def add(self, event, **kwargs):
super(SerializableTokenCache, self).add(event, **kwargs)
self.has_state_changed = True

def remove_rt(self, rt_item):
super(SerializableTokenCache, self).remove_rt(rt_item)
self.has_state_changed = True

def update_rt(self, rt_item, new_rt):
super(SerializableTokenCache, self).update_rt(rt_item, new_rt)
def modify(self, credential_type, old_entry, new_key_value_pairs=None):
super(SerializableTokenCache, self).modify(
credential_type, old_entry, new_key_value_pairs)
self.has_state_changed = True

def deserialize(self, state):
Expand Down
34 changes: 33 additions & 1 deletion tests/test_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,11 +174,14 @@ def setUp(self):
self.account = {"home_account_id": "{}.{}".format(self.uid, self.utid)}
self.frt = "what the frt"
self.cache = msal.SerializableTokenCache()
self.preexisting_family_app_id = "preexisting_family_app"
self.cache.add({ # Pre-populate a FRT
"client_id": "preexisting_family_app",
"client_id": self.preexisting_family_app_id,
"scope": self.scopes,
"token_endpoint": "{}/oauth2/v2.0/token".format(self.authority_url),
"response": TokenCacheTestCase.build_response(
access_token="Siblings won't share AT. test_remove_account() will.",
id_token=TokenCacheTestCase.build_id_token(),
uid=self.uid, utid=self.utid, refresh_token=self.frt, foci="1"),
}) # The add(...) helper populates correct home_account_id for future searching

Expand Down Expand Up @@ -239,6 +242,35 @@ def tester(url, data=None, **kwargs):

# Will not test scenario of app leaving family. Per specs, it won't happen.

def test_family_app_remove_account(self):
logger.debug("%s.cache = %s", self.id(), self.cache.serialize())
app = ClientApplication(
self.preexisting_family_app_id,
authority=self.authority_url, token_cache=self.cache)
account = app.get_accounts()[0]
mine = {"home_account_id": account["home_account_id"]}

self.assertNotEqual([], self.cache.find(
self.cache.CredentialType.ACCESS_TOKEN, query=mine))
self.assertNotEqual([], self.cache.find(
self.cache.CredentialType.REFRESH_TOKEN, query=mine))
self.assertNotEqual([], self.cache.find(
self.cache.CredentialType.ID_TOKEN, query=mine))
self.assertNotEqual([], self.cache.find(
self.cache.CredentialType.ACCOUNT, query=mine))

app.remove_account(account)

self.assertEqual([], self.cache.find(
self.cache.CredentialType.ACCESS_TOKEN, query=mine))
self.assertEqual([], self.cache.find(
self.cache.CredentialType.REFRESH_TOKEN, query=mine))
self.assertEqual([], self.cache.find(
self.cache.CredentialType.ID_TOKEN, query=mine))
self.assertEqual([], self.cache.find(
self.cache.CredentialType.ACCOUNT, query=mine))


class TestClientApplicationForAuthorityMigration(unittest.TestCase):

@classmethod
Expand Down