Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PyJWT async #956

Draft
wants to merge 11 commits into
base: master
Choose a base branch
from
1 change: 1 addition & 0 deletions .tool-versions
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
python 3.8.10
1 change: 1 addition & 0 deletions docs/usage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,7 @@ Retrieve RSA signing keys from a JWKS endpoint
>>> optional_custom_headers = {"User-agent": "custom-user-agent"}
>>> jwks_client = PyJWKClient(url, headers=optional_custom_headers)
>>> signing_key = jwks_client.get_signing_key_from_jwt(token)
>>> # signing_key = await jwks_client.get_siging_key_from_jwt_async(token) # if calling from an async context
>>> data = jwt.decode(
... token,
... signing_key,
Expand Down
2 changes: 1 addition & 1 deletion jwt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
)
from .jwks_client import PyJWKClient

__version__ = "2.8.0"
__version__ = "2.9.0"

__title__ = "PyJWT"
__description__ = "JSON Web Token implementation in Python"
Expand Down
4 changes: 4 additions & 0 deletions jwt/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,7 @@ class PyJWKClientError(PyJWTError):

class PyJWKClientConnectionError(PyJWKClientError):
pass


class PyJWKAsyncDisabledError(PyJWKClientError):
pass
96 changes: 95 additions & 1 deletion jwt/jwks_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,21 @@

from .api_jwk import PyJWK, PyJWKSet
from .api_jwt import decode_complete as decode_token
from .exceptions import PyJWKClientConnectionError, PyJWKClientError
from .exceptions import (
PyJWKAsyncDisabledError,
PyJWKClientConnectionError,
PyJWKClientError,
)
from .jwk_set_cache import JWKSetCache

try:
from async_lru import alru_cache
from httpx import AsyncClient, HTTPError, Timeout

has_async = True
except ModuleNotFoundError:
has_async = False


class PyJWKClient:
def __init__(
Expand Down Expand Up @@ -46,6 +58,8 @@ def __init__(
# Cache signing keys
# Ignore mypy (https://github.com/python/mypy/issues/2427)
self.get_signing_key = lru_cache(maxsize=max_cached_keys)(self.get_signing_key) # type: ignore
if has_async:
self.get_signing_key_async = alru_cache(maxsize=max_cached_keys)(self.get_signing_key_async) # type: ignore

def fetch_data(self) -> Any:
jwk_set: Any = None
Expand All @@ -65,6 +79,31 @@ def fetch_data(self) -> Any:
if self.jwk_set_cache is not None:
self.jwk_set_cache.put(jwk_set)

async def fetch_data_async(self) -> Any:
if not has_async:
raise PyJWKAsyncDisabledError()
jwk_set: Any = None
try:
async with AsyncClient(
timeout=Timeout(self.timeout),
verify=self.ssl_context,
) as client:
response = await client.get(
url=self.uri,
headers=self.headers,
)
response.raise_for_status() # Raise an exception for HTTP errors
jwk_set = response.json()
except HTTPError as e:
raise PyJWKClientConnectionError(
f'Fail to fetch data from the url, err: "{e}"'
)
else:
return jwk_set
finally:
if self.jwk_set_cache is not None:
self.jwk_set_cache.put(jwk_set)

def get_jwk_set(self, refresh: bool = False) -> PyJWKSet:
data = None
if self.jwk_set_cache is not None and not refresh:
Expand All @@ -78,6 +117,21 @@ def get_jwk_set(self, refresh: bool = False) -> PyJWKSet:

return PyJWKSet.from_dict(data)

async def get_jwk_set_async(self, refresh: bool = False) -> PyJWKSet:
if not has_async:
raise PyJWKAsyncDisabledError()
data = None
if self.jwk_set_cache is not None and not refresh:
data = self.jwk_set_cache.get()

if data is None:
data = await self.fetch_data_async()

if not isinstance(data, dict):
raise PyJWKClientError("The JWKS endpoint did not return a JSON object")

return PyJWKSet.from_dict(data)

def get_signing_keys(self, refresh: bool = False) -> List[PyJWK]:
jwk_set = self.get_jwk_set(refresh)
signing_keys = [
Expand All @@ -91,6 +145,21 @@ def get_signing_keys(self, refresh: bool = False) -> List[PyJWK]:

return signing_keys

async def get_signing_keys_async(self, refresh: bool = False) -> List[PyJWK]:
if not has_async:
raise PyJWKAsyncDisabledError()
jwk_set = await self.get_jwk_set_async(refresh)
signing_keys = [
jwk_set_key
for jwk_set_key in jwk_set.keys
if jwk_set_key.public_key_use in ["sig", None] and jwk_set_key.key_id
]

if not signing_keys:
raise PyJWKClientError("The JWKS endpoint did not contain any signing keys")

return signing_keys

def get_signing_key(self, kid: str) -> PyJWK:
signing_keys = self.get_signing_keys()
signing_key = self.match_kid(signing_keys, kid)
Expand All @@ -107,11 +176,36 @@ def get_signing_key(self, kid: str) -> PyJWK:

return signing_key

async def get_signing_key_async(self, kid: str) -> PyJWK:
if not has_async:
raise PyJWKAsyncDisabledError()
signing_keys = await self.get_signing_keys_async()
signing_key = self.match_kid(signing_keys, kid)

if not signing_key:
# If no matching signing key from the jwk set, refresh the jwk set and try again.
signing_keys = await self.get_signing_keys_async(refresh=True)
signing_key = self.match_kid(signing_keys, kid)

if not signing_key:
raise PyJWKClientError(
f'Unable to find a signing key that matches: "{kid}"'
)

return signing_key

def get_signing_key_from_jwt(self, token: str) -> PyJWK:
unverified = decode_token(token, options={"verify_signature": False})
header = unverified["header"]
return self.get_signing_key(header.get("kid"))

async def get_signing_key_from_jwt_async(self, token: str) -> PyJWK:
if not has_async:
raise PyJWKAsyncDisabledError()
unverified = decode_token(token, options={"verify_signature": False})
header = unverified["header"]
return await self.get_signing_key_async(header.get("kid"))

@staticmethod
def match_kid(signing_keys: List[PyJWK], kid: str) -> Optional[PyJWK]:
signing_key = None
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
requires = ["setuptools"]
build-backend = "setuptools.build_meta"

[tool.check-manifest]
ignore = [".tool-versions"]

[tool.coverage.run]
parallel = true
branch = true
Expand Down
4 changes: 4 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,12 @@ docs =
zope.interface
crypto =
cryptography>=3.4.0
async =
httpx==0.27.0
async-lru==2.0.4
tests =
pytest>=6.0.0,<7.0.0
pytest-asyncio==0.20.3
coverage[toml]==5.0.4
dev =
sphinx
Expand Down
41 changes: 41 additions & 0 deletions tests/test_jwks_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,20 @@ def mocked_success_response(data):
yield urlopen_mock


@contextlib.asynccontextmanager
async def mocked_success_response_async(data):
with mock.patch.object(
jwt.jwks_client,
"AsyncClient",
) as async_client_mock:
response = mock.AsyncMock()
response.__aenter__ = mock.AsyncMock(return_value=response)
response.__aexit__ = mock.AsyncMock()
response.get = mock.AsyncMock(return_value=[json.dumps(data)])
async_client_mock.return_value = response
yield async_client_mock


@contextlib.contextmanager
def mocked_failed_response():
with mock.patch("urllib.request.urlopen") as urlopen_mock:
Expand Down Expand Up @@ -219,6 +233,33 @@ def test_get_signing_key_from_jwt(self):
"gty": "client-credentials",
}

@pytest.mark.asyncio
async def test_get_signing_key_from_jwt_async(self):
token = "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsImtpZCI6Ik5FRTFRVVJCT1RNNE16STVSa0ZETlRZeE9UVTFNRGcyT0Rnd1EwVXpNVGsxUWpZeVJrUkZRdyJ9.eyJpc3MiOiJodHRwczovL2Rldi04N2V2eDlydS5hdXRoMC5jb20vIiwic3ViIjoiYVc0Q2NhNzl4UmVMV1V6MGFFMkg2a0QwTzNjWEJWdENAY2xpZW50cyIsImF1ZCI6Imh0dHBzOi8vZXhwZW5zZXMtYXBpIiwiaWF0IjoxNTcyMDA2OTU0LCJleHAiOjE1NzIwMDY5NjQsImF6cCI6ImFXNENjYTc5eFJlTFdVejBhRTJINmtEME8zY1hCVnRDIiwiZ3R5IjoiY2xpZW50LWNyZWRlbnRpYWxzIn0.PUxE7xn52aTCohGiWoSdMBZGiYAHwE5FYie0Y1qUT68IHSTXwXVd6hn02HTah6epvHHVKA2FqcFZ4GGv5VTHEvYpeggiiZMgbxFrmTEY0csL6VNkX1eaJGcuehwQCRBKRLL3zKmA5IKGy5GeUnIbpPHLHDxr-GXvgFzsdsyWlVQvPX2xjeaQ217r2PtxDeqjlf66UYl6oY6AqNS8DH3iryCvIfCcybRZkc_hdy-6ZMoKT6Piijvk_aXdm7-QQqKJFHLuEqrVSOuBqqiNfVrG27QzAPuPOxvfXTVLXL2jek5meH6n-VWgrBdoMFH93QEszEDowDAEhQPHVs0xj7SIzA"
url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json"

async with mocked_success_response_async(RESPONSE_DATA_WITH_MATCHING_KID):
jwks_client = PyJWKClient(url)
signing_key = await jwks_client.get_signing_key_from_jwt_async(token)

data = jwt.decode(
token,
signing_key.key,
algorithms=["RS256"],
audience="https://expenses-api",
options={"verify_exp": False},
)

assert data == {
"iss": "https://dev-87evx9ru.auth0.com/",
"sub": "aW4Cca79xReLWUz0aE2H6kD0O3cXBVtC@clients",
"aud": "https://expenses-api",
"iat": 1572006954,
"exp": 1572006964,
"azp": "aW4Cca79xReLWUz0aE2H6kD0O3cXBVtC",
"gty": "client-credentials",
}

def test_get_jwk_set_caches_result(self):
url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json"

Expand Down
Loading