Skip to content

Commit

Permalink
Replace python-jose with pyjwt (#1875)
Browse files Browse the repository at this point in the history
* Replace python-jose with pyjwt.

* Replace non-existent get_unverified_claims function

* Change Exception to handle JWT-specific errors

* Convert the public key to PEM format

* Add pem format tests

* Another test, plus autherror fixes

---------

Co-authored-by: Pamela Fox <pamelafox@microsoft.com>
  • Loading branch information
blutril and pamelafox authored Aug 7, 2024
1 parent 27816c1 commit a8b1202
Show file tree
Hide file tree
Showing 4 changed files with 207 additions and 46 deletions.
63 changes: 35 additions & 28 deletions app/backend/core/authentication.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
# Refactored from https://github.com/Azure-Samples/ms-identity-python-on-behalf-of

import base64
import json
import logging
from typing import Any, Optional

import aiohttp
import jwt
from azure.search.documents.aio import SearchClient
from azure.search.documents.indexes.models import SearchIndex
from jose import jwt
from jose.exceptions import ExpiredSignatureError, JWTClaimsError
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from msal import ConfidentialClientApplication
from msal.token_cache import TokenCache
from tenacity import (
Expand Down Expand Up @@ -282,6 +284,24 @@ async def check_path_auth(self, path: str, auth_claims: dict[str, Any], search_c

return allowed

async def create_pem_format(self, jwks, token):
unverified_header = jwt.get_unverified_header(token)
for key in jwks["keys"]:
if key["kid"] == unverified_header["kid"]:
# Construct the RSA public key
public_numbers = rsa.RSAPublicNumbers(
e=int.from_bytes(base64.urlsafe_b64decode(key["e"] + "=="), byteorder="big"),
n=int.from_bytes(base64.urlsafe_b64decode(key["n"] + "=="), byteorder="big"),
)
public_key = public_numbers.public_key()

# Convert to PEM format
pem_key = public_key.public_bytes(
encoding=serialization.Encoding.PEM, format=serialization.PublicFormat.SubjectPublicKeyInfo
)
rsa_key = pem_key
return rsa_key

# See https://github.com/Azure-Samples/ms-identity-python-on-behalf-of/blob/939be02b11f1604814532fdacc2c2eccd198b755/FlaskAPI/helpers/authorization.py#L44
async def validate_access_token(self, token: str):
"""
Expand All @@ -304,51 +324,38 @@ async def validate_access_token(self, token: str):
jwks = await resp.json()

if not jwks or "keys" not in jwks:
raise AuthError({"code": "invalid_keys", "description": "Unable to get keys to validate auth token."}, 401)
raise AuthError("Unable to get keys to validate auth token.", 401)

rsa_key = None
issuer = None
audience = None
try:
unverified_header = jwt.get_unverified_header(token)
unverified_claims = jwt.get_unverified_claims(token)
unverified_claims = jwt.decode(token, options={"verify_signature": False})
issuer = unverified_claims.get("iss")
audience = unverified_claims.get("aud")
for key in jwks["keys"]:
if key["kid"] == unverified_header["kid"]:
rsa_key = {"kty": key["kty"], "kid": key["kid"], "use": key["use"], "n": key["n"], "e": key["e"]}
break
except Exception as exc:
raise AuthError(
{"code": "invalid_header", "description": "Unable to parse authorization token."}, 401
) from exc
rsa_key = await self.create_pem_format(jwks, token)
except jwt.PyJWTError as exc:
raise AuthError("Unable to parse authorization token.", 401) from exc
if not rsa_key:
raise AuthError({"code": "invalid_header", "description": "Unable to find appropriate key"}, 401)
raise AuthError("Unable to find appropriate key", 401)

if issuer not in self.valid_issuers:
raise AuthError(
{"code": "invalid_header", "description": f"Issuer {issuer} not in {','.join(self.valid_issuers)}"}, 401
)
raise AuthError(f"Issuer {issuer} not in {','.join(self.valid_issuers)}", 401)

if audience not in self.valid_audiences:
raise AuthError(
{
"code": "invalid_header",
"description": f"Audience {audience} not in {','.join(self.valid_audiences)}",
},
f"Audience {audience} not in {','.join(self.valid_audiences)}",
401,
)

try:
jwt.decode(token, rsa_key, algorithms=["RS256"], audience=audience, issuer=issuer)
except ExpiredSignatureError as jwt_expired_exc:
raise AuthError({"code": "token_expired", "description": "token is expired"}, 401) from jwt_expired_exc
except JWTClaimsError as jwt_claims_exc:
except jwt.ExpiredSignatureError as jwt_expired_exc:
raise AuthError("Token is expired", 401) from jwt_expired_exc
except (jwt.InvalidAudienceError, jwt.InvalidIssuerError) as jwt_claims_exc:
raise AuthError(
{"code": "invalid_claims", "description": "incorrect claims," "please check the audience and issuer"},
"Incorrect claims: please check the audience and issuer",
401,
) from jwt_claims_exc
except Exception as exc:
raise AuthError(
{"code": "invalid_header", "description": "Unable to parse authorization token."}, 401
) from exc
raise AuthError("Unable to parse authorization token.", 401) from exc
3 changes: 1 addition & 2 deletions app/backend/requirements.in
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@ opentelemetry-instrumentation-aiohttp-client
opentelemetry-instrumentation-openai
msal
cryptography
python-jose[cryptography]
types-python-jose
PyJWT
Pillow
types-Pillow
pypdf
Expand Down
15 changes: 0 additions & 15 deletions app/backend/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,10 @@ cryptography==43.0.0
# azure-storage-blob
# msal
# pyjwt
# python-jose
deprecated==1.2.14
# via opentelemetry-api
distro==1.9.0
# via openai
ecdsa==0.19.0
# via python-jose
fixedint==0.1.6
# via azure-monitor-opentelemetry-exporter
flask==3.0.3
Expand Down Expand Up @@ -324,10 +321,6 @@ priority==2.0.0
# via hypercorn
psutil==5.9.8
# via azure-monitor-opentelemetry-exporter
pyasn1==0.6.0
# via
# python-jose
# rsa
pycparser==2.22
# via cffi
pydantic==2.8.2
Expand All @@ -349,8 +342,6 @@ python-dateutil==2.9.0.post0
# microsoft-kiota-serialization-text
# pendulum
# time-machine
python-jose[cryptography]==3.3.0
# via -r requirements.in
quart==0.19.6
# via
# -r requirements.in
Expand All @@ -368,8 +359,6 @@ requests==2.32.3
# tiktoken
requests-oauthlib==2.0.0
# via msrest
rsa==4.9
# via python-jose
six==1.16.0
# via
# azure-core
Expand Down Expand Up @@ -402,10 +391,6 @@ types-html5lib==1.1.11.20240228
# via types-beautifulsoup4
types-pillow==10.2.0.20240520
# via -r requirements.in
types-pyasn1==0.6.0.20240402
# via types-python-jose
types-python-jose==3.3.4.20240106
# via -r requirements.in
typing-extensions==4.12.2
# via
# azure-ai-documentintelligence
Expand Down
172 changes: 171 additions & 1 deletion tests/test_authenticationhelper.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
import base64
import json
import re
from datetime import datetime, timedelta

import aiohttp
import jwt
import pytest
from azure.core.credentials import AzureKeyCredential
from azure.search.documents.aio import SearchClient
from azure.search.documents.indexes.models import SearchField, SearchIndex
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa

from core.authentication import AuthenticationHelper, AuthError

from .mocks import MockAsyncPageIterator
from .mocks import MockAsyncPageIterator, MockResponse

MockSearchIndex = SearchIndex(
name="test",
Expand Down Expand Up @@ -40,6 +47,36 @@ def create_search_client():
return SearchClient(endpoint="", index_name="", credential=AzureKeyCredential(""))


def create_mock_jwt(kid="mock_kid", oid="OID_X"):
# Create a payload with necessary claims
payload = {
"iss": "https://login.microsoftonline.com/TENANT_ID/v2.0",
"sub": "AaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaA",
"aud": "SERVER_APP",
"exp": int((datetime.utcnow() + timedelta(hours=1)).timestamp()),
"iat": int(datetime.utcnow().timestamp()),
"nbf": int(datetime.utcnow().timestamp()),
"name": "John Doe",
"oid": oid,
"preferred_username": "john.doe@example.com",
"rh": "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA.",
"tid": "22222222-2222-2222-2222-222222222222",
"uti": "AbCdEfGhIjKlMnOp-ABCDEFG",
"ver": "2.0",
}

# Create a header
header = {"kid": kid, "alg": "RS256", "typ": "JWT"}

# Create a mock private key (for signing)
private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)

# Create the JWT
token = jwt.encode(payload, private_key, algorithm="RS256", headers=header)

return token, private_key.public_key(), payload


@pytest.mark.asyncio
async def test_get_auth_claims_success(mock_confidential_client_success, mock_validate_token_success):
helper = create_authentication_helper()
Expand Down Expand Up @@ -479,3 +516,136 @@ async def mock_search(self, *args, **kwargs):
)
assert filter is None
assert called_search is False


@pytest.mark.asyncio
async def test_create_pem_format(mock_confidential_client_success, mock_validate_token_success):
helper = create_authentication_helper()
mock_token, public_key, payload = create_mock_jwt(oid="OID_X")
_, other_public_key, _ = create_mock_jwt(oid="OID_Y")
mock_jwks = {
"keys": [
# Include a key with a different KID to ensure the correct key is selected
{
"kty": "RSA",
"kid": "other_mock_kid",
"use": "sig",
"n": base64.urlsafe_b64encode(
other_public_key.public_numbers().n.to_bytes(
(other_public_key.public_numbers().n.bit_length() + 7) // 8, byteorder="big"
)
)
.decode("utf-8")
.rstrip("="),
"e": base64.urlsafe_b64encode(
other_public_key.public_numbers().e.to_bytes(
(other_public_key.public_numbers().e.bit_length() + 7) // 8, byteorder="big"
)
)
.decode("utf-8")
.rstrip("="),
},
{
"kty": "RSA",
"kid": "mock_kid",
"use": "sig",
"n": base64.urlsafe_b64encode(
public_key.public_numbers().n.to_bytes(
(public_key.public_numbers().n.bit_length() + 7) // 8, byteorder="big"
)
)
.decode("utf-8")
.rstrip("="),
"e": base64.urlsafe_b64encode(
public_key.public_numbers().e.to_bytes(
(public_key.public_numbers().e.bit_length() + 7) // 8, byteorder="big"
)
)
.decode("utf-8")
.rstrip("="),
},
]
}

pem_key = await helper.create_pem_format(mock_jwks, mock_token)

# Assert that the result is bytes
assert isinstance(pem_key, bytes), "create_pem_format should return bytes"

# Convert bytes to string for regex matching
pem_str = pem_key.decode("utf-8")

# Assert that the key starts and ends with the correct markers
assert pem_str.startswith("-----BEGIN PUBLIC KEY-----"), "PEM key should start with the correct marker"
assert pem_str.endswith("-----END PUBLIC KEY-----\n"), "PEM key should end with the correct marker"

# Assert that the format matches the structure of a PEM key
pem_regex = r"^-----BEGIN PUBLIC KEY-----\n([A-Za-z0-9+/\n]+={0,2})\n-----END PUBLIC KEY-----\n$"
assert re.match(pem_regex, pem_str), "PEM key format is incorrect"

# Verify that the key can be used to decode the token
try:
decoded = jwt.decode(
mock_token, key=pem_key, algorithms=["RS256"], audience=payload["aud"], issuer=payload["iss"]
)
assert decoded["oid"] == payload["oid"], "Decoded token should contain correct OID"
except Exception as e:
pytest.fail(f"jwt.decode raised an unexpected exception: {str(e)}")

# Try to load the key using cryptography library to ensure it's a valid PEM format
try:
loaded_public_key = serialization.load_pem_public_key(pem_key)
assert isinstance(loaded_public_key, rsa.RSAPublicKey), "Loaded key should be an RSA public key"
except Exception as e:
pytest.fail(f"Failed to load PEM key: {str(e)}")


@pytest.mark.asyncio
async def test_validate_access_token(monkeypatch, mock_confidential_client_success):
mock_token, public_key, payload = create_mock_jwt(oid="OID_X")

def mock_get(*args, **kwargs):
return MockResponse(
status=200,
text=json.dumps(
{
"keys": [
{
"kty": "RSA",
"use": "sig",
"kid": "23nt",
"x5t": "23nt",
"n": "hu2SJ",
"e": "AQAB",
"x5c": ["MIIC/jCC"],
"issuer": "https://login.microsoftonline.com/TENANT_ID/v2.0",
},
{
"kty": "RSA",
"use": "sig",
"kid": "MGLq",
"x5t": "MGLq",
"n": "yfNcG8",
"e": "AQAB",
"x5c": ["MIIC/jCC"],
"issuer": "https://login.microsoftonline.com/TENANT_ID/v2.0",
},
]
}
),
)

monkeypatch.setattr(aiohttp.ClientSession, "get", mock_get)

def mock_decode(*args, **kwargs):
return payload

monkeypatch.setattr(jwt, "decode", mock_decode)

async def mock_create_pem_format(*args, **kwargs):
return public_key

monkeypatch.setattr(AuthenticationHelper, "create_pem_format", mock_create_pem_format)

helper = create_authentication_helper()
await helper.validate_access_token(mock_token)

0 comments on commit a8b1202

Please sign in to comment.