-
Notifications
You must be signed in to change notification settings - Fork 82
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added Token Validations #1682
Added Token Validations #1682
Changes from 28 commits
155db7a
b70b55c
88c03e5
5d1c2c9
02b5d36
7117d55
cba99e9
09caacb
0e77776
e6fcf8f
58a5440
c04760a
b5c0c08
25334bc
fca14f2
64fa2f1
d416a55
913c1ce
a8cfe08
647b689
eddcb86
3b4418d
f30dd4f
320f652
ff910c7
9b9cbad
451b0d4
0e329a3
5a0893b
8b8eb1a
26bd79c
86db3e8
8db81a1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -24,12 +24,14 @@ | |
] | ||
ENGINE = get_engine(envname=ENVNAME) | ||
ALLOWED_ORIGINS = os.getenv('ALLOWED_ORIGINS', '*') | ||
AWS_REGION = os.getenv('AWS_REGION') | ||
|
||
|
||
def redact_creds(event): | ||
if 'headers' in event and 'Authorization' in event['headers']: | ||
if event.get('headers', {}).get('Authorization'): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can anyone please tell why are we redacting these ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. user creds information does not need to be logged and is no longer relevant for the remaining request lifecycle - opting to redact that info There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Instead of redacting , can we just extract useful information and then pass it onto our lambda ? |
||
event['headers']['Authorization'] = 'XXXXXXXXXXXX' | ||
if 'multiValueHeaders' in event and 'Authorization' in event['multiValueHeaders']: | ||
|
||
if event.get('multiValueHeaders', {}).get('Authorization'): | ||
event['multiValueHeaders']['Authorization'] = 'XXXXXXXXXXXX' | ||
return event | ||
|
||
|
@@ -115,7 +117,7 @@ def check_reauth(query, auth_time, username): | |
# Determine if there are any Operations that Require ReAuth From SSM Parameter | ||
try: | ||
reauth_apis = ParameterStoreManager.get_parameter_value( | ||
region=os.getenv('AWS_REGION', 'eu-west-1'), parameter_path=f'/dataall/{ENVNAME}/reauth/apis' | ||
region=AWS_REGION, parameter_path=f'/dataall/{ENVNAME}/reauth/apis' | ||
).split(',') | ||
except Exception: | ||
log.info('No ReAuth APIs Found in SSM') | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
import logging | ||
import os | ||
import json | ||
|
||
from auth_services import AuthServices | ||
from jwt_services import JWTServices | ||
|
@@ -16,21 +17,32 @@ | |
Custom Lambda Authorizer is attached to the API Gateway. Check the deploy/stacks/lambda_api.py for more details on deployment | ||
""" | ||
|
||
OPENID_CONFIG_PATH = os.path.join(os.environ.get('custom_auth_url', ''), '.well-known', 'openid-configuration') | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sure that is a good call out - just made that change to |
||
jwt_service = JWTServices(OPENID_CONFIG_PATH) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: I would capitalize global variables There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. wasn't sure for global class instance - capitalizing now |
||
|
||
|
||
def lambda_handler(incoming_event, context): | ||
# Get the Token which is sent in the Authorization Header | ||
logger.debug(incoming_event) | ||
petrkalos marked this conversation as resolved.
Show resolved
Hide resolved
|
||
auth_token = incoming_event['headers']['Authorization'] | ||
if not auth_token: | ||
raise Exception('Unauthorized . Token not found') | ||
raise Exception('Unauthorized. Missing identity or access JWT') | ||
|
||
verified_claims = JWTServices.validate_jwt_token(auth_token) | ||
logger.debug(verified_claims) | ||
# Validate User is Active with Proper Access Token | ||
user_info = jwt_service.validate_access_token(auth_token) | ||
|
||
# Validate JWT | ||
verified_claims = jwt_service.validate_jwt_token(auth_token[7:]) | ||
if not verified_claims: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since the method |
||
raise Exception('Unauthorized. Token is not valid') | ||
logger.debug(verified_claims) | ||
|
||
# Generate Allow Policy w/ Context | ||
effect = 'Allow' | ||
verified_claims.update(user_info) | ||
policy = AuthServices.generate_policy(verified_claims, effect, incoming_event['methodArn']) | ||
logger.debug('Generated policy is ', policy) | ||
logger.debug(f'Generated policy is {json.dumps(policy)}') | ||
print(f'Generated policy is {json.dumps(policy)}') | ||
return policy | ||
|
||
|
||
|
@@ -39,12 +51,13 @@ def lambda_handler(incoming_event, context): | |
# AWS Lambda and any other local environments | ||
if __name__ == '__main__': | ||
# for testing locally you can enter the JWT ID Token here | ||
token = '' | ||
# | ||
access_token = 'Bearer eyJraWQiOiJtYTJ6SUxrbVMtQW1qZzZwVGtqZjhkN3JxY1FaNWE2eWtLS3dGQkFZckJBIiwidHlwIjoiYXBwbGljYXRpb25cL29rdGEtaW50ZXJuYWwtYXQrand0IiwiYWxnIjoiUlMyNTYifQ.eyJ2ZXIiOjEsImp0aSI6IkFULm9DSERGSHpVdGFUeTFDQXBENWF3amZRMEdyUzNPcEpyNE93czdnM3JKUXciLCJpc3MiOiJodHRwczovL2Rldi0zNzAxMTAxMC5va3RhLmNvbSIsImF1ZCI6Imh0dHBzOi8vZGV2LTM3MDExMDEwLm9rdGEuY29tIiwic3ViIjoibm9haHBhaWdAYW1hem9uLmNvbSIsImlhdCI6MTczMDkyNTQ1MiwiZXhwIjoxNzMwOTI5MDUyLCJjaWQiOiIwb2FkcndpcmVxcldoanFYaTVkNyIsInVpZCI6IjAwdWRydTNtNTZWS3hnWEtKNWQ3Iiwic2NwIjpbIm9wZW5pZCIsImVtYWlsIiwicHJvZmlsZSJdLCJhdXRoX3RpbWUiOjE3MzA5MjU0NTF9.uFZ123U7nbu6rN0L9WB2EZQTEZCnMcYOV_6uS4XRb8TAREcat-Kk88rLXONLwNWSaLaqGXOsr1tC1bd9FdTXyWG9WmVkihep8un_tmy1V410vEBtzXes6nqsr4-QZsx7csrWWtDetm4T7Smtl621z4isL8ePdYtkWe_2SELJjiOpr8qQ8pXMVEwMY8kiu-VuZHUXNnFGvrIRtNytsNzFVunbQxOX58uCq_J5eU7MRbj0tBAYqLXgXrj1iskb17uGHL4IqIWl1Te6qk05bLMZ9RrySEpyuCmYDPIgFpUZNiewLUNgPTNb4I8wrKycTpNfEEhTiLNxjo7QA5y2stTrFg' | ||
account_id = '' | ||
api_gw_id = '' | ||
event = { | ||
'headers': {'Authorization': access_token}, | ||
'type': 'TOKEN', | ||
'Authorization': token, | ||
'methodArn': f'arn:aws:execute-api:us-east-1:{account_id}:{api_gw_id}/prod/POST/graphql/api', | ||
} | ||
lambda_handler(event, None) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,101 +1,79 @@ | ||
import os | ||
|
||
import requests | ||
from jose import jwk | ||
from jose.jwt import get_unverified_header, decode, ExpiredSignatureError, JWTError | ||
import jwt | ||
|
||
import logging | ||
|
||
logger = logging.getLogger() | ||
logger.setLevel(os.environ.get('LOG_LEVEL', 'INFO')) | ||
|
||
# Configs required to fetch public keys from JWKS | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We don't need this because pyJWT does it with some inbuilt method ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes - https://pyjwt.readthedocs.io/en/latest/usage.html (re: |
||
ISSUER_CONFIGS = { | ||
f'{os.environ.get("custom_auth_url")}': { | ||
'jwks_uri': f'{os.environ.get("custom_auth_jwks_url")}', | ||
'allowed_audiences': f'{os.environ.get("custom_auth_client")}', | ||
}, | ||
} | ||
|
||
issuer_keys = {} | ||
|
||
|
||
# instead of re-downloading the public keys every time | ||
# we download them only on cold start | ||
# https://aws.amazon.com/blogs/compute/container-reuse-in-lambda/ | ||
def fetch_public_keys(): | ||
try: | ||
for issuer, issuer_config in ISSUER_CONFIGS.items(): | ||
jwks_response = requests.get(issuer_config['jwks_uri']) | ||
jwks_response.raise_for_status() | ||
jwks: dict = jwks_response.json() | ||
for key in jwks['keys']: | ||
value = { | ||
'issuer': issuer, | ||
'audience': issuer_config['allowed_audiences'], | ||
'jwk': jwk.construct(key), | ||
'public_key': jwk.construct(key).public_key(), | ||
} | ||
issuer_keys.update({key['kid']: value}) | ||
except Exception as e: | ||
raise Exception(f'Unable to fetch public keys due to {str(e)}') | ||
|
||
|
||
fetch_public_keys() | ||
|
||
# Options to validate the JWT token | ||
# Only modification from default is to turn off verify_at_hash as we don't provide the access token for this validation | ||
# Only modification from default is to turn off verify_aud as Cognito Access Token does not provide this claim | ||
jwt_options = { | ||
'verify_signature': True, | ||
'verify_aud': True, | ||
'verify_aud': False, | ||
'verify_iat': True, | ||
'verify_exp': True, | ||
'verify_nbf': True, | ||
'verify_iss': True, | ||
'verify_sub': True, | ||
'verify_jti': True, | ||
'verify_at_hash': False, | ||
'require_aud': True, | ||
'require_iat': True, | ||
'require_exp': True, | ||
'require_nbf': False, | ||
'require_iss': True, | ||
'require_sub': True, | ||
'require_jti': True, | ||
'require_at_hash': False, | ||
'leeway': 0, | ||
'require': ['iat', 'exp', 'iss', 'sub', 'jti'], | ||
} | ||
|
||
|
||
class JWTServices: | ||
petrkalos marked this conversation as resolved.
Show resolved
Hide resolved
|
||
@staticmethod | ||
def validate_jwt_token(jwt_token): | ||
def __init__(self, openid_config_path): | ||
# Get OpenID Config JSON | ||
self.openid_config = self._fetch_openid_config(openid_config_path) | ||
|
||
# Init pyJWT.JWKClient with JWK URI | ||
self.jwks_client = jwt.PyJWKClient(self.openid_config.get('jwks_uri')) | ||
|
||
def _fetch_openid_config(self, openid_config_path): | ||
response = requests.get(openid_config_path) | ||
response.raise_for_status() | ||
return response.json() | ||
|
||
def validate_jwt_token(self, jwt_token) -> dict: | ||
try: | ||
# Decode and verify the JWT token | ||
header = get_unverified_header(jwt_token) | ||
kid = header['kid'] | ||
if kid not in issuer_keys: | ||
logger.info('Public key not found in provided set of keys') | ||
# Retry Fetching the public certificates again in case rotation occurs and lambda has cached the publicKeys | ||
fetch_public_keys() | ||
if kid not in issuer_keys: | ||
raise Exception('Unauthorized') | ||
public_key = issuer_keys.get(kid) | ||
payload = decode( | ||
# get signing_key from JWT | ||
signing_key = self.jwks_client.get_signing_key_from_jwt(jwt_token) | ||
|
||
# Decode and Verify JWT | ||
payload = jwt.decode( | ||
jwt_token, | ||
public_key.get('jwk'), | ||
signing_key.key, | ||
algorithms=['RS256', 'HS256'], | ||
issuer=public_key.get('issuer'), | ||
audience=public_key.get('audience'), | ||
issuer=os.environ.get('custom_auth_url'), | ||
audience=os.environ.get('custom_auth_client'), | ||
leeway=0, | ||
options=jwt_options, | ||
) | ||
|
||
# verify client_id if Cognito JWT | ||
if os.environ['custom_auth_provider'] == 'Cognito' and payload['client_id'] != os.environ.get( | ||
'custom_auth_client' | ||
): | ||
raise Exception('Invalid Client ID in JWT Token') | ||
|
||
return payload | ||
except ExpiredSignatureError: | ||
except jwt.exceptions.ExpiredSignatureError as e: | ||
logger.error('JWT token has expired.') | ||
return None | ||
except JWTError as e: | ||
raise e | ||
except jwt.exceptions.PyJWTError as e: | ||
logger.error(f'JWT token validation failed: {str(e)}') | ||
return None | ||
raise e | ||
except Exception as e: | ||
logger.error(f'Failed to validate token - {str(e)}') | ||
return None | ||
raise e | ||
|
||
def validate_access_token(self, access_token) -> dict: | ||
# get UserInfo URI from OpenId Configuration | ||
user_info_url = self.openid_config.get('userinfo_endpoint') | ||
r = requests.get(user_info_url, headers={'Authorization': access_token}) | ||
r.raise_for_status() | ||
logger.debug(r.json()) | ||
return r.json() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,9 @@ | ||
certifi==2024.7.4 | ||
charset-normalizer==3.1.0 | ||
ecdsa==0.18.0 | ||
idna==3.7 | ||
pyasn1==0.5.0 | ||
python-jose==3.3.0 | ||
requests==2.32.2 | ||
rsa==4.9 | ||
six==1.16.0 | ||
urllib3==1.26.19 | ||
urllib3==1.26.19 | ||
pyjwt==2.9.0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: That's why I don't like this approach, you have to remember to obfuscate stuff