Skip to content

Commit

Permalink
fix: Validate Azure JWTs using authlib (#2112)
Browse files Browse the repository at this point in the history
* Use joserfc to decode Azure jwt

* Apply black formatting

* Add joserfc to requirements.txt

* Remove unnecessary import statement

* Fix import order

* Switch to authlib to stay compatible with Python 3.7

* Fix import order

* Move Microsoft key set URL to constants

* Make logging less verbose

* Add unittest

* Remove authlib from install_requires

* Fix formatting

* Update unittest

* Fix import order

* Move authlib to requirements-extra.txt
  • Loading branch information
wolfdn authored Oct 6, 2023
1 parent e4d613a commit 57f4400
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 36 deletions.
6 changes: 6 additions & 0 deletions flask_appbuilder/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,3 +191,9 @@
API_ADD_TITLE_RIS_KEY = "add_title"
API_EDIT_TITLE_RIS_KEY = "edit_title"
API_SHOW_TITLE_RIS_KEY = "show_title"

# -----------------------------------
# OAuth Provider Constants
# -----------------------------------

MICROSOFT_KEY_SET_URL = "https://login.microsoftonline.com/common/discovery/keys"
47 changes: 12 additions & 35 deletions flask_appbuilder/security/manager.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
import base64
import datetime
import json
import logging
import re
from typing import Any, Dict, List, Optional, Set, Tuple, Union

from authlib.jose import JsonWebKey, jwt
from flask import Flask, g, session, url_for
from flask_babel import lazy_gettext as _
from flask_jwt_extended import current_user as current_user_jwt
from flask_jwt_extended import JWTManager
from flask_limiter import Limiter
from flask_limiter.util import get_remote_address
from flask_login import current_user, LoginManager
import requests
from werkzeug.security import check_password_hash, generate_password_hash

from .api import SecurityApi
Expand Down Expand Up @@ -54,6 +55,7 @@
LOGMSG_WAR_SEC_LOGIN_FAILED,
LOGMSG_WAR_SEC_NO_USER,
LOGMSG_WAR_SEC_NOLDAP_OBJ,
MICROSOFT_KEY_SET_URL,
PERMISSION_PREFIX,
)

Expand Down Expand Up @@ -632,11 +634,9 @@ def get_oauth_user_info(self, provider, resp):
# https://docs.microsoft.com/en-us/azure/active-directory/develop/
# active-directory-protocols-oauth-code
if provider == "azure":
log.debug("Azure response received : %s", resp)
id_token = resp["id_token"]
log.debug(str(id_token))
me = self._azure_jwt_token_parse(id_token)
log.debug("Parse JWT token : %s", me)
log.debug("Azure response received:\n%s", json.dumps(resp, indent=4))
me = self._decode_and_validate_azure_jwt(resp["id_token"])
log.debug("Decoded JWT:\n%s", json.dumps(me, indent=4))
return {
"name": me.get("name", ""),
"email": me["upn"],
Expand Down Expand Up @@ -683,36 +683,13 @@ def get_oauth_user_info(self, provider, resp):
else:
return {}

def _azure_parse_jwt(self, id_token):
jwt_token_parts = r"^([^\.\s]*)\.([^\.\s]+)\.([^\.\s]*)$"
matches = re.search(jwt_token_parts, id_token)
if not matches or len(matches.groups()) < 3:
log.error("Unable to parse token.")
return {}
return {
"header": matches.group(1),
"Payload": matches.group(2),
"Sig": matches.group(3),
}

def _azure_jwt_token_parse(self, id_token):
jwt_split_token = self._azure_parse_jwt(id_token)
if not jwt_split_token:
return

jwt_payload = jwt_split_token["Payload"]
# Prepare for base64 decoding
payload_b64_string = jwt_payload
payload_b64_string += "=" * (4 - ((len(jwt_payload) % 4)))
decoded_payload = base64.urlsafe_b64decode(payload_b64_string.encode("ascii"))

if not decoded_payload:
log.error("Payload of id_token could not be base64 url decoded.")
return

jwt_decoded_payload = json.loads(decoded_payload.decode("utf-8"))
def _decode_and_validate_azure_jwt(self, id_token):
keyset = JsonWebKey.import_key_set(requests.get(MICROSOFT_KEY_SET_URL).json())
claims = jwt.decode(id_token, keyset)
claims.validate()
log.debug("Decoded JWT:\n%s", json.dumps(claims, indent=4))

return jwt_decoded_payload
return claims

def register_views(self):
if not self.appbuilder.app.config.get("FAB_ADD_SECURITY_VIEWS", True):
Expand Down
2 changes: 1 addition & 1 deletion requirements-extra.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@ mysqlclient==2.0.1
psycopg2-binary==2.9.6
pyodbc==4.0.35
requests==2.26.0
Authlib==0.15.4
Authlib==1.2.1
python-ldap==3.3.1
flask-openid==1.3.0
13 changes: 13 additions & 0 deletions tests/security/test_base_security_manager.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import datetime
import json
import unittest
from unittest.mock import MagicMock, patch

from flask_appbuilder.security.manager import BaseSecurityManager
from flask_appbuilder.security.manager import JsonWebKey, jwt

JWTClaimsMock = MagicMock()


@patch.object(BaseSecurityManager, "update_user")
Expand Down Expand Up @@ -67,3 +71,12 @@ def test_subsequent_unsuccessful_auth(self, mock1, mock2):
self.assertEqual(user_mock.fail_login_count, 10)
self.assertEqual(user_mock.last_login, None)
self.assertTrue(bsm.update_user.called_once)

@patch.object(JsonWebKey, "import_key_set", MagicMock())
@patch.object(jwt, "decode", MagicMock(return_value=JWTClaimsMock))
@patch.object(json, "dumps", MagicMock(return_value="DecodedExampleAzureJWT"))
def test_azure_jwt_validated(self, mock1, mock2):
bsm = BaseSecurityManager()

bsm._decode_and_validate_azure_jwt("ExampleAzureJWT")
JWTClaimsMock.validate.assert_called()

0 comments on commit 57f4400

Please sign in to comment.