Skip to content

Commit

Permalink
refactor(socialaccount): Extract JWT verification
Browse files Browse the repository at this point in the history
  • Loading branch information
pennersr committed Feb 8, 2024
1 parent 9c08094 commit 701bcc6
Show file tree
Hide file tree
Showing 5 changed files with 120 additions and 98 deletions.
Empty file.
86 changes: 86 additions & 0 deletions allauth/socialaccount/internal/jwtkit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import json

import jwt
from cryptography.hazmat.backends import default_backend
from cryptography.x509 import load_pem_x509_certificate

from allauth.socialaccount.adapter import get_adapter
from allauth.socialaccount.providers.oauth2.client import OAuth2Error


def lookup_kid_pem_x509_certificate(keys_data, kid):
"""
Looks up the key given keys data of the form:
{"<kid>": "-----BEGIN CERTIFICATE-----\nCERTIFICATE"}
"""
key = keys_data.get(kid)
if key:
public_key = load_pem_x509_certificate(
key.encode("utf8"), default_backend()
).public_key()
return public_key


def lookup_kid_jwk(keys_data, kid):
"""
Looks up the key given keys data of the form:
{
"keys": [
{
"kty": "RSA",
"kid": "W6WcOKB",
"use": "sig",
"alg": "RS256",
"n": "2Zc5d0-zk....",
"e": "AQAB"
}]
}
"""
for d in keys_data["keys"]:
if d["kid"] == kid:
public_key = jwt.algorithms.RSAAlgorithm.from_jwk(json.dumps(d))
return public_key


def fetch_key(credential, keys_url, lookup):
header = jwt.get_unverified_header(credential)
# {'alg': 'RS256', 'kid': '0ad1fec78504f447bae65bcf5afaedb65eec9e81', 'typ': 'JWT'}
kid = header["kid"]
alg = header["alg"]
response = get_adapter().get_requests_session().get(keys_url)
response.raise_for_status()
keys_data = response.json()
key = lookup(keys_data, kid)
if not key:
raise OAuth2Error(f"Invalid 'kid': '{kid}'")
return alg, key


def verify_and_decode(
*, credential, keys_url, issuer, audience, lookup_kid, verify_signature=True
):
try:
if verify_signature:
alg, key = fetch_key(credential, keys_url, lookup_kid)
algorithms = [alg]
else:
key = ""
algorithms = None
data = jwt.decode(
credential,
key=key,
options={
"verify_signature": verify_signature,
"verify_iss": True,
"verify_aud": True,
"verify_exp": True,
},
issuer=issuer,
audience=audience,
algorithms=algorithms,
)
return data
except jwt.PyJWTError as e:
raise OAuth2Error("Invalid id_token") from e
48 changes: 9 additions & 39 deletions allauth/socialaccount/providers/apple/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,9 @@
from django.utils.http import urlencode
from django.views.decorators.csrf import csrf_exempt

import jwt

from allauth.socialaccount.adapter import get_adapter
from allauth.socialaccount.internal import jwtkit
from allauth.socialaccount.models import SocialToken
from allauth.socialaccount.providers.oauth2.client import OAuth2Error
from allauth.socialaccount.providers.oauth2.views import (
OAuth2Adapter,
OAuth2CallbackView,
Expand All @@ -31,49 +29,21 @@ class AppleOAuth2Adapter(OAuth2Adapter):
authorize_url = "https://appleid.apple.com/auth/authorize"
public_key_url = "https://appleid.apple.com/auth/keys"

def _get_apple_public_key(self, kid):
response = get_adapter().get_requests_session().get(self.public_key_url)
response.raise_for_status()
try:
data = response.json()
except json.JSONDecodeError as e:
raise OAuth2Error("Error retrieving apple public key.") from e

for d in data["keys"]:
if d["kid"] == kid:
return d

def get_public_key(self, id_token):
"""
Get the public key which matches the `kid` in the id_token header.
"""
kid = jwt.get_unverified_header(id_token)["kid"]
apple_public_key = self._get_apple_public_key(kid=kid)

public_key = jwt.algorithms.RSAAlgorithm.from_jwk(json.dumps(apple_public_key))
return public_key

def get_client_id(self, provider):
app = get_adapter().get_app(request=None, provider=self.provider_id)
return [aud.strip() for aud in app.client_id.split(",")]

def get_verified_identity_data(self, id_token):
provider = self.get_provider()
allowed_auds = self.get_client_id(provider)

try:
public_key = self.get_public_key(id_token)
identity_data = jwt.decode(
id_token,
public_key,
algorithms=["RS256"],
audience=allowed_auds,
issuer="https://appleid.apple.com",
)
return identity_data

except jwt.PyJWTError as e:
raise OAuth2Error("Invalid id_token") from e
data = jwtkit.verify_and_decode(
credential=id_token,
keys_url=self.public_key_url,
issuer="https://appleid.apple.com",
audience=allowed_auds,
lookup_kid=jwtkit.lookup_kid_jwk,
)
return data

def parse_token(self, data):
token = SocialToken(
Expand Down
6 changes: 3 additions & 3 deletions allauth/socialaccount/providers/google/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,14 +233,14 @@ class AppInSettingsTests(GoogleTests):
def test_login_by_token(db, client, settings_with_google_provider):
client.cookies.load({"g_csrf_token": "csrf"})
with patch(
"allauth.socialaccount.providers.google.views.jwt.get_unverified_header"
"allauth.socialaccount.internal.jwtkit.jwt.get_unverified_header"
) as g_u_h:
with mocked_response({"dummykid": "-----BEGIN CERTIFICATE-----"}):
with patch(
"allauth.socialaccount.providers.google.views.load_pem_x509_certificate"
"allauth.socialaccount.internal.jwtkit.load_pem_x509_certificate"
) as load_pem:
with patch(
"allauth.socialaccount.providers.google.views.jwt.decode"
"allauth.socialaccount.internal.jwtkit.jwt.decode"
) as decode:
decode.return_value = {
"iss": "https://accounts.google.com",
Expand Down
78 changes: 22 additions & 56 deletions allauth/socialaccount/providers/google/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,12 @@
from django.views.decorators.csrf import csrf_exempt
from django.views.generic import View

import jwt
from cryptography.hazmat.backends import default_backend
from cryptography.x509 import load_pem_x509_certificate

from allauth.socialaccount.adapter import get_adapter
from allauth.socialaccount.helpers import (
complete_social_login,
render_authentication_error,
)
from allauth.socialaccount.internal import jwtkit
from allauth.socialaccount.providers.oauth2.client import OAuth2Error
from allauth.socialaccount.providers.oauth2.views import (
OAuth2Adapter,
Expand Down Expand Up @@ -61,6 +58,17 @@
)


def _verify_and_decode(app, credential, verify_signature=True):
return jwtkit.verify_and_decode(
credential=credential,
keys_url=CERTS_URL,
issuer=ID_TOKEN_ISSUER,
audience=app.client_id,
lookup_kid=jwtkit.lookup_kid_pem_x509_certificate,
verify_signature=verify_signature,
)


class GoogleOAuth2Adapter(OAuth2Adapter):
provider_id = GoogleProvider.id
access_token_url = ACCESS_TOKEN_URL
Expand All @@ -85,27 +93,14 @@ def complete_login(self, request, app, token, response, **kwargs):
return login

def _decode_id_token(self, app, id_token):
try:
data = jwt.decode(
id_token,
# Since the token was received by direct communication
# protected by TLS between this library and Google, we
# are allowed to skip checking the token signature
# according to the OpenID Connect Core 1.0
# specification.
# https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
options={
"verify_signature": False,
"verify_iss": True,
"verify_aud": True,
"verify_exp": True,
},
issuer=self.id_token_issuer,
audience=app.client_id,
)
except jwt.PyJWTError as e:
raise OAuth2Error("Invalid id_token") from e
return data
"""
Since the token was received by direct communication protected by
TLS between this library and Google, we are allowed to skip checking the
token signature according to the OpenID Connect Core 1.0 specification.
https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
"""
return _verify_and_decode(app, id_token, verify_signature=False)

def _fetch_user_info(self, access_token):
resp = (
Expand All @@ -132,9 +127,9 @@ def dispatch(self, request):
try:
return super().dispatch(request)
except (
OAuth2Error,
requests.RequestException,
PermissionDenied,
jwt.PyJWTError,
) as exc:
return render_authentication_error(request, self.provider, exception=exc)

Expand All @@ -147,20 +142,7 @@ def post(self, request, *args, **kwargs):
self.check_csrf(request)

credential = request.POST.get("credential")
alg, key = self.get_key(credential)
identity_data = jwt.decode(
credential,
key,
options={
"verify_signature": True,
"verify_iss": True,
"verify_aud": True,
"verify_exp": True,
},
issuer=ID_TOKEN_ISSUER,
audience=self.provider.app.client_id,
algorithms=[alg],
)
identity_data = _verify_and_decode(app=self.provider.app, credential=credential)
login = self.provider.sociallogin_from_response(request, identity_data)
return complete_social_login(request, login)

Expand All @@ -174,21 +156,5 @@ def check_csrf(self, request):
if csrf_token_cookie != csrf_token_body:
raise PermissionDenied("Failed to verify double submit cookie.")

def get_key(self, credential):
header = jwt.get_unverified_header(credential)
# {'alg': 'RS256', 'kid': '0ad1fec78504f447bae65bcf5afaedb65eec9e81', 'typ': 'JWT'}
kid = header["kid"]
alg = header["alg"]
response = get_adapter().get_requests_session().get(CERTS_URL)
response.raise_for_status()
jwks = response.json()
key = jwks.get(kid)
if not key:
raise PermissionDenied("invalid 'kid'")
key = load_pem_x509_certificate(
key.encode("utf8"), default_backend()
).public_key()
return alg, key


login_by_token = csrf_exempt(LoginByTokenView.as_view())

0 comments on commit 701bcc6

Please sign in to comment.