Skip to content

Commit

Permalink
migrate from python-jose to pyjwt
Browse files Browse the repository at this point in the history
  • Loading branch information
dvdalilue committed May 23, 2024
1 parent 562882b commit 60603fb
Show file tree
Hide file tree
Showing 10 changed files with 128 additions and 145 deletions.
4 changes: 2 additions & 2 deletions demo_project/api/api_v1/endpoints/graph.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from typing import Any

import httpx
import jwt
from demo_project.api.dependencies import azure_scheme
from demo_project.core.config import settings
from fastapi import APIRouter, Depends, Request
from httpx import AsyncClient
from jose import jwt

router = APIRouter()

Expand Down Expand Up @@ -47,7 +47,7 @@ async def graph_world(request: Request) -> Any: # noqa: ANN401

# Return all the information to the end user
return (
{'claims': jwt.get_unverified_claims(token=request.state.user.access_token)}
{'claims': jwt.decode(request.state.user.access_token, options={'verify_signature': False})}
| {'obo_response': obo_response.json()}
| {'graph_response': graph}
)
83 changes: 55 additions & 28 deletions fastapi_azure_auth/auth.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,33 @@
import inspect
import logging
from typing import Any, Awaitable, Callable, Dict, Literal, Optional
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Literal, Optional
from warnings import warn

import jwt
from fastapi.exceptions import HTTPException
from fastapi.security import OAuth2AuthorizationCodeBearer, SecurityScopes
from fastapi.security.base import SecurityBase
from jose import jwt
from jose.exceptions import ExpiredSignatureError, JWTClaimsError, JWTError
from jwt.exceptions import (
ExpiredSignatureError,
ImmatureSignatureError,
InvalidAlgorithmError,
InvalidAudienceError,
InvalidIssuedAtError,
InvalidIssuerError,
InvalidTokenError,
MissingRequiredClaimError,
)
from starlette.requests import Request

from fastapi_azure_auth.exceptions import InvalidAuth
from fastapi_azure_auth.openid_config import OpenIdConfig
from fastapi_azure_auth.user import User
from fastapi_azure_auth.utils import is_guest
from fastapi_azure_auth.utils import get_unverified_claims, get_unverified_header, is_guest

if TYPE_CHECKING:
from jwt.algorithms import AllowedPublicKeys
else:
AllowedPublicKeys = Any

log = logging.getLogger('fastapi_azure_auth')

Expand Down Expand Up @@ -147,9 +161,11 @@ async def __call__(self, request: Request, security_scopes: SecurityScopes) -> O
try:
access_token = await self.oauth(request=request)
try:
if access_token is None:
raise Exception('No access token provided')
# Extract header information of the token.
header: dict[str, str] = jwt.get_unverified_header(token=access_token) or {}
claims: dict[str, Any] = jwt.get_unverified_claims(token=access_token) or {}
header: dict[str, Any] = get_unverified_header(access_token)
claims: dict[str, Any] = get_unverified_claims(access_token)
except Exception as error:
log.warning('Malformed token received. %s. Error: %s', access_token, error, exc_info=True)
raise InvalidAuth(detail='Invalid token format') from error
Expand Down Expand Up @@ -180,48 +196,41 @@ async def __call__(self, request: Request, security_scopes: SecurityScopes) -> O
try:
if key := self.openid_config.signing_keys.get(header.get('kid', '')):
# We require and validate all fields in an Azure AD token
required_claims = ['exp', 'aud', 'iat', 'nbf', 'sub']
if self.validate_iss:
required_claims.append('iss')

options = {
'verify_signature': True,
'verify_aud': True,
'verify_iat': True,
'verify_exp': True,
'verify_nbf': True,
'verify_iss': self.validate_iss,
'verify_sub': True,
'verify_jti': True,
'verify_at_hash': True,
'require_aud': True,
'require_iat': True,
'require_exp': True,
'require_nbf': True,
'require_iss': self.validate_iss,
'require_sub': True,
'require_jti': False,
'require_at_hash': False,
'leeway': self.leeway,
'require': required_claims,
}
# Validate token
token = jwt.decode(
access_token,
key=key,
algorithms=['RS256'],
audience=self.app_client_id if self.token_version == 2 else f'api://{self.app_client_id}',
issuer=iss,
options=options,
)
token = self.validate(access_token=access_token, iss=iss, key=key, options=options)
# Attach the user to the request. Can be accessed through `request.state.user`
user: User = User(
**{**token, 'claims': token, 'access_token': access_token, 'is_guest': user_is_guest}
)
request.state.user = user
return user
except JWTClaimsError as error:
except (
InvalidAudienceError,
InvalidIssuerError,
InvalidIssuedAtError,
ImmatureSignatureError,
InvalidAlgorithmError,
MissingRequiredClaimError,
) as error:
log.info('Token contains invalid claims. %s', error)
raise InvalidAuth(detail='Token contains invalid claims') from error
except ExpiredSignatureError as error:
log.info('Token signature has expired. %s', error)
raise InvalidAuth(detail='Token signature has expired') from error
except JWTError as error:
except InvalidTokenError as error:
log.warning('Invalid token. Error: %s', error, exc_info=True)
raise InvalidAuth(detail='Unable to validate token') from error
except Exception as error:
Expand All @@ -235,6 +244,24 @@ async def __call__(self, request: Request, security_scopes: SecurityScopes) -> O
return None
raise

def validate(self, access_token: str, key: AllowedPublicKeys, iss: str, options: Dict[str, Any]) -> Dict[str, Any]:
"""
Validates the token using the provided key and options.
"""
alg = 'RS256'
aud = self.app_client_id if self.token_version == 2 else f'api://{self.app_client_id}'
return dict(
jwt.decode(
access_token,
key=key,
algorithms=[alg],
audience=aud,
issuer=iss,
leeway=self.leeway,
options=options,
)
)


class SingleTenantAzureAuthorizationCodeBearer(AzureAuthorizationCodeBearerBase):
def __init__(
Expand Down
16 changes: 10 additions & 6 deletions fastapi_azure_auth/openid_config.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import logging
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional

from cryptography.hazmat.primitives.asymmetric.types import PublicKeyTypes as KeyTypes
import jwt
from fastapi import HTTPException, status
from httpx import AsyncClient
from jose import jwk

if TYPE_CHECKING:
from jwt.algorithms import AllowedPublicKeys
else:
AllowedPublicKeys = Any

log = logging.getLogger('fastapi_azure_auth')

Expand All @@ -27,7 +31,7 @@ def __init__(
self.config_url = config_url

self.authorization_endpoint: str
self.signing_keys: dict[str, KeyTypes]
self.signing_keys: dict[str, AllowedPublicKeys]
self.token_endpoint: str
self.issuer: str

Expand Down Expand Up @@ -98,6 +102,6 @@ def _load_keys(self, keys: List[Dict[str, Any]]) -> None:
for key in keys:
if key.get('use') == 'sig': # Only care about keys that are used for signatures, not encryption
log.debug('Loading public key from certificate: %s', key)
cert_obj = jwk.construct(key, 'RS256')
cert_obj = jwt.PyJWK(key, 'RS256')
if kid := key.get('kid'): # In case a key would not have a thumbprint we can match, we don't want it.
self.signing_keys[kid] = cert_obj
self.signing_keys[kid] = cert_obj.key
20 changes: 20 additions & 0 deletions fastapi_azure_auth/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import Any, Dict

import jwt


def is_guest(claims: Dict[str, Any]) -> bool:
"""
Expand All @@ -12,3 +14,21 @@ def is_guest(claims: Dict[str, Any]) -> bool:
claims_iss: str = claims.get('iss', '')
idp: str = claims.get('idp', claims_iss)
return idp != claims_iss


def get_unverified_header(access_token: str | None) -> Dict[str, Any]:
"""
Get header from the access token without verifying the signature
"""
if access_token is None:
return {}
return dict(jwt.get_unverified_header(access_token))


def get_unverified_claims(access_token: str | None) -> Dict[str, Any]:
"""
Get claims from the access token without verifying the signature
"""
if access_token is None:
return {}
return dict(jwt.decode(access_token, options={'verify_signature': False}))
85 changes: 19 additions & 66 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ classifiers = [
python = "^3.8"
fastapi = ">0.68.0"
cryptography = ">=40.0.1"
python-jose = {extras = ["cryptography"], version = "^3.3.0"}
httpx = ">0.18.2"
pyjwt = "^2.8.0"


[tool.poetry.group.dev.dependencies]
Expand Down
3 changes: 2 additions & 1 deletion tests/multi_tenant/test_multi_tenant.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
)

from fastapi_azure_auth import MultiTenantAzureAuthorizationCodeBearer
from fastapi_azure_auth.auth import AzureAuthorizationCodeBearerBase
from fastapi_azure_auth.exceptions import InvalidAuth


Expand Down Expand Up @@ -283,7 +284,7 @@ async def test_only_header(multi_tenant_app, mock_openid_and_keys):

@pytest.mark.anyio
async def test_exception_raised(multi_tenant_app, mock_openid_and_keys, mocker):
mocker.patch('fastapi_azure_auth.auth.jwt.decode', side_effect=ValueError('lol'))
mocker.patch.object(AzureAuthorizationCodeBearerBase, 'validate', side_effect=ValueError('lol'))
async with AsyncClient(
app=app,
base_url='http://test',
Expand Down
Loading

0 comments on commit 60603fb

Please sign in to comment.