Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add cacheing functionality for JWK set #781

Merged
merged 22 commits into from
Aug 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
9cebaee
Initial implementation of ttl jwk set cache
wuhaoyujerry Jun 29, 2022
41b6f7f
Add unit test for jwk set cache
wuhaoyujerry Jun 29, 2022
544d7ca
Fix failed unit test
wuhaoyujerry Jun 30, 2022
ac9ae72
Disable cache signing key by default
wuhaoyujerry Jun 30, 2022
a209586
Add a negative unit test for get_jwk_set
wuhaoyujerry Jun 30, 2022
b21d67b
Add functionality to force refresh the jwk set cache when no matching…
wuhaoyujerry Jun 30, 2022
70d1d2f
Add unit test for refresh cache
wuhaoyujerry Jul 1, 2022
d611b8d
Add unit test to unset cache when the network call throws error
wuhaoyujerry Jul 1, 2022
56076f0
fix naming typo
wuhaoyujerry Jul 1, 2022
a4c28d1
Update unit test naming
wuhaoyujerry Jul 1, 2022
e4b29b0
Update comment
wuhaoyujerry Jul 1, 2022
2c1bd08
Add check for lifespan
wuhaoyujerry Jul 11, 2022
913017e
Update comments for get_signing_key
wuhaoyujerry Jul 11, 2022
9ad63c2
Merge pull request #1 from wuhaoyujerry/jwk_set_cache
wuhaoyujerry Jul 11, 2022
a064024
Merge branch 'jpadilla:master' into master
wuhaoyujerry Jul 11, 2022
f7e3cad
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 11, 2022
3dfe73f
Fix ci error
wuhaoyujerry Jul 11, 2022
8c595b3
Add type declaration to fix CI error
wuhaoyujerry Jul 11, 2022
e5dc7f7
Add more unit tests to improve coverage
wuhaoyujerry Jul 11, 2022
6d91497
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 11, 2022
6bb6aa7
Try to increase test coverage to 100%
wuhaoyujerry Jul 11, 2022
1aff672
Merge branch 'jpadilla:master' into master
wuhaoyujerry Jul 22, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions jwt/api_jwk.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import time

from .algorithms import get_default_algorithms
from .exceptions import InvalidKeyError, PyJWKError, PyJWKSetError
Expand Down Expand Up @@ -108,3 +109,15 @@ def __getitem__(self, kid):
if key.key_id == kid:
return key
raise KeyError(f"keyset has no key for kid: {kid}")


class PyJWTSetWithTimestamp:
def __init__(self, jwk_set: PyJWKSet):
self.jwk_set = jwk_set
self.timestamp = time.monotonic()

def get_jwk_set(self):
return self.jwk_set

def get_timestamp(self):
return self.timestamp
32 changes: 32 additions & 0 deletions jwt/jwk_set_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import time
from typing import Optional

from .api_jwk import PyJWKSet, PyJWTSetWithTimestamp


class JWKSetCache:
def __init__(self, lifespan: int):
self.jwk_set_with_timestamp: Optional[PyJWTSetWithTimestamp] = None
self.lifespan = lifespan

def put(self, jwk_set: PyJWKSet):
if jwk_set is not None:
self.jwk_set_with_timestamp = PyJWTSetWithTimestamp(jwk_set)
else:
# clear cache
self.jwk_set_with_timestamp = None

def get(self) -> Optional[PyJWKSet]:
if self.jwk_set_with_timestamp is None or self.is_expired():
return None

return self.jwk_set_with_timestamp.get_jwk_set()

def is_expired(self) -> bool:

return (
self.jwk_set_with_timestamp is not None
and self.lifespan > -1
and time.monotonic()
> self.jwk_set_with_timestamp.get_timestamp() + self.lifespan
)
82 changes: 65 additions & 17 deletions jwt/jwks_client.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,68 @@
import json
import urllib.request
from functools import lru_cache
from typing import Any, List
from typing import Any, List, Optional
from urllib.error import URLError

from .api_jwk import PyJWK, PyJWKSet
from .api_jwt import decode_complete as decode_token
from .exceptions import PyJWKClientError
from .jwk_set_cache import JWKSetCache


class PyJWKClient:
def __init__(self, uri: str, cache_keys: bool = True, max_cached_keys: int = 16):
def __init__(
self,
uri: str,
cache_keys: bool = False,
max_cached_keys: int = 16,
cache_jwk_set: bool = True,
lifespan: int = 300,
):
self.uri = uri
self.jwk_set_cache: Optional[JWKSetCache] = None

if cache_jwk_set:
# Init jwt set cache with default or given lifespan.
# Default lifespan is 300 seconds (5 minutes).
if lifespan <= 0:
raise PyJWKClientError(
f'Lifespan must be greater than 0, the input is "{lifespan}"'
)
self.jwk_set_cache = JWKSetCache(lifespan)
else:
self.jwk_set_cache = None

if cache_keys:
# Cache signing keys
# Ignore mypy (https://github.com/python/mypy/issues/2427)
self.get_signing_key = lru_cache(maxsize=max_cached_keys)(self.get_signing_key) # type: ignore

def fetch_data(self) -> Any:
with urllib.request.urlopen(self.uri) as response:
return json.load(response)
jwk_set: Any = None
try:
with urllib.request.urlopen(self.uri) as response:
jwk_set = json.load(response)
except URLError as e:
raise PyJWKClientError(f'Fail to fetch data from the url, err: "{e}"')
else:
return jwk_set
finally:
if self.jwk_set_cache is not None:
self.jwk_set_cache.put(jwk_set)

def get_jwk_set(self, refresh: bool = False) -> PyJWKSet:
data = None
if self.jwk_set_cache is not None and not refresh:
data = self.jwk_set_cache.get()

if data is None:
data = self.fetch_data()

def get_jwk_set(self) -> PyJWKSet:
data = self.fetch_data()
return PyJWKSet.from_dict(data)

def get_signing_keys(self) -> List[PyJWK]:
jwk_set = self.get_jwk_set()
def get_signing_keys(self, refresh: bool = False) -> List[PyJWK]:
jwk_set = self.get_jwk_set(refresh)
signing_keys = [
jwk_set_key
for jwk_set_key in jwk_set.keys
Expand All @@ -39,21 +76,32 @@ def get_signing_keys(self) -> List[PyJWK]:

def get_signing_key(self, kid: str) -> PyJWK:
signing_keys = self.get_signing_keys()
signing_key = None

for key in signing_keys:
if key.key_id == kid:
signing_key = key
break
signing_key = self.match_kid(signing_keys, kid)

if not signing_key:
raise PyJWKClientError(
f'Unable to find a signing key that matches: "{kid}"'
)
# If no matching signing key from the jwk set, refresh the jwk set and try again.
signing_keys = self.get_signing_keys(refresh=True)
signing_key = self.match_kid(signing_keys, kid)

if not signing_key:
raise PyJWKClientError(
f'Unable to find a signing key that matches: "{kid}"'
)

return signing_key

def get_signing_key_from_jwt(self, token: str) -> PyJWK:
unverified = decode_token(token, options={"verify_signature": False})
header = unverified["header"]
return self.get_signing_key(header.get("kid"))

@staticmethod
def match_kid(signing_keys: List[PyJWK], kid: str) -> Optional[PyJWK]:
signing_key = None

for key in signing_keys:
if key.key_id == kid:
signing_key = key
break

return signing_key
Loading