Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for stateless code flow #41

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
description='OpenID Connect Provider (OP) library in Python.',
install_requires=[
'oic >= 1.2.1',
'pycryptodomex',
],
extras_require={
'mongo': 'pymongo',
Expand Down
113 changes: 88 additions & 25 deletions src/pyop/authz_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .exceptions import InvalidRefreshToken
from .exceptions import InvalidScope
from .exceptions import InvalidSubjectIdentifier
from .storage import StatelessWrapper
from .util import requested_scope_is_allowed

logger = logging.getLogger(__name__)
Expand All @@ -24,13 +25,15 @@ def rand_str():

class AuthorizationState(object):
KEY_AUTHORIZATION_REQUEST = 'auth_req'
KEY_USER_INFO = 'user_info'
KEY_EXTRA_ID_TOKEN_CLAIMS = 'extra_id_token_claims'

def __init__(self, subject_identifier_factory, authorization_code_db=None, access_token_db=None,
refresh_token_db=None, subject_identifier_db=None, *,
authorization_code_lifetime=600, access_token_lifetime=3600, refresh_token_lifetime=None,
refresh_token_threshold=None):
# type: (se_leg_op.token_state.SubjectIdentifierFactory, Mapping[str, Any], Mapping[str, Any],
# Mapping[str, Any], Mapping[str, Any], int, int, Optional[int], Optional[int]) -> None
# Mapping[str, Any], Mapping[str, Any], int, int, Optional[int], Optional[int]) -> None
"""
:param subject_identifier_factory: callable to use when construction subject identifiers
:param authorization_code_db: database for storing authorization codes, defaults to in-memory
Expand Down Expand Up @@ -77,10 +80,28 @@ def __init__(self, subject_identifier_factory, authorization_code_db=None, acces
"""
Mapping of user id's to subject identifiers.
"""
self.subject_identifiers = subject_identifier_db if subject_identifier_db is not None else {}

def create_authorization_code(self, authorization_request, subject_identifier, scope=None):
# type: (AuthorizationRequest, str, Optional[List[str]]) -> str
self.stateless = (
isinstance(self.authorization_codes, StatelessWrapper)
or isinstance(self.access_tokens, StatelessWrapper)
or isinstance(self.refresh_tokens, StatelessWrapper)
)
self.subject_identifiers = (
{}
if self.stateless
else subject_identifier_db
if subject_identifier_db is not None
else {}
)

def create_authorization_code(
self,
authorization_request,
subject_identifier,
scope=None,
user_info=None,
extra_id_token_claims=None,
):
# type: (AuthorizationRequest, str, Optional[List[str]], Optional[dict], Optional[Mappings[str, Union[str, List[str]]]]) -> str
"""
Creates an authorization code bound to the authorization request and the authenticated user identified
by the subject identifier.
Expand All @@ -92,21 +113,29 @@ def create_authorization_code(self, authorization_request, subject_identifier, s
scope = ' '.join(scope or authorization_request['scope'])
logger.debug('creating authz code for scope=%s', scope)

authorization_code = rand_str()
authz_info = {
'used': False,
'exp': int(time.time()) + self.authorization_code_lifetime,
'sub': subject_identifier,
'granted_scope': scope,
self.KEY_AUTHORIZATION_REQUEST: authorization_request.to_dict()
}
self.authorization_codes[authorization_code] = authz_info

if self.stateless:
if user_info:
authz_info[self.KEY_USER_INFO] = user_info
authz_info[self.KEY_EXTRA_ID_TOKEN_CLAIMS] = extra_id_token_claims or {}
authorization_code = self.authorization_codes.pack(authz_info)
else:
authorization_code = rand_str()
self.authorization_codes[authorization_code] = authz_info

logger.debug('new authz_code=%s to client_id=%s for sub=%s valid_until=%s', authorization_code,
authorization_request['client_id'], subject_identifier, authz_info['exp'])
return authorization_code

def create_access_token(self, authorization_request, subject_identifier, scope=None):
# type: (AuthorizationRequest, str, Optional[List[str]]) -> se_leg_op.access_token.AccessToken
def create_access_token(self, authorization_request, subject_identifier, scope=None, user_info=None):
# type: (AuthorizationRequest, str, Optional[List[str]], Optional[dict]) -> se_leg_op.access_token.AccessToken
"""
Creates an access token bound to the authentication request and the authenticated user identified by the
subject identifier.
Expand All @@ -116,15 +145,15 @@ def create_access_token(self, authorization_request, subject_identifier, scope=N

scope = scope or authorization_request['scope']

return self._create_access_token(subject_identifier, authorization_request.to_dict(), ' '.join(scope))
return self._create_access_token(subject_identifier, authorization_request.to_dict(), ' '.join(scope),
user_info=user_info)

def _create_access_token(self, subject_identifier, auth_req, granted_scope, current_scope=None):
# type: (str, Mapping[str, Union[str, List[str]]], str, Optional[str]) -> se_leg_op.access_token.AccessToken
def _create_access_token(self, subject_identifier, auth_req, granted_scope, current_scope=None,
user_info=None):
# type: (str, Mapping[str, Union[str, List[str]]], str, Optional[str], Optional[dict]) -> se_leg_op.access_token.AccessToken
"""
Creates an access token bound to the subject identifier, client id and requested scope.
"""
access_token = AccessToken(rand_str(), self.access_token_lifetime)

scope = current_scope or granted_scope
logger.debug('creating access token for scope=%s', scope)

Expand All @@ -136,13 +165,21 @@ def _create_access_token(self, subject_identifier, auth_req, granted_scope, curr
'aud': [auth_req['client_id']],
'scope': scope,
'granted_scope': granted_scope,
'token_type': access_token.BEARER_TOKEN_TYPE,
'token_type': AccessToken.BEARER_TOKEN_TYPE,
self.KEY_AUTHORIZATION_REQUEST: auth_req
}
self.access_tokens[access_token.value] = authz_info

if self.stateless:
if user_info:
authz_info[self.KEY_USER_INFO] = user_info
access_token_val = self.access_tokens.pack(authz_info)
else:
access_token_val = rand_str()
self.access_tokens[access_token_val] = authz_info

logger.debug('new access_token=%s to client_id=%s for sub=%s valid_until=%s',
access_token.value, auth_req['client_id'], subject_identifier, authz_info['exp'])
access_token_val, auth_req['client_id'], subject_identifier, authz_info['exp'])
access_token = AccessToken(access_token_val, self.access_token_lifetime)
return access_token

def exchange_code_for_token(self, authorization_code):
Expand All @@ -165,7 +202,8 @@ def exchange_code_for_token(self, authorization_code):
authz_info['used'] = True

access_token = self._create_access_token(authz_info['sub'], authz_info[self.KEY_AUTHORIZATION_REQUEST],
authz_info['granted_scope'])
authz_info['granted_scope'],
user_info=authz_info.get(self.KEY_USER_INFO))

logger.debug('authz_code=%s exchanged to access_token=%s', authorization_code, access_token.value)
return access_token
Expand Down Expand Up @@ -199,9 +237,13 @@ def create_refresh_token(self, access_token_value):
logger.debug('no refresh token issued for for access_token=%s', access_token_value)
return None

refresh_token = rand_str()
authz_info = {'access_token': access_token_value, 'exp': int(time.time()) + self.refresh_token_lifetime}
self.refresh_tokens[refresh_token] = authz_info

if self.stateless:
refresh_token = self.refresh_tokens.pack(authz_info)
else:
refresh_token = rand_str()
self.refresh_tokens[refresh_token] = authz_info

logger.debug('issued refresh_token=%s expiring=%d for access_token=%s', refresh_token, authz_info['exp'],
access_token_value)
Expand Down Expand Up @@ -235,7 +277,8 @@ def use_refresh_token(self, refresh_token, scope=None):
scope = authz_info['granted_scope']

new_access_token = self._create_access_token(authz_info['sub'], authz_info[self.KEY_AUTHORIZATION_REQUEST],
authz_info['granted_scope'], scope)
authz_info['granted_scope'], scope,
user_info=authz_info.get(self.KEY_USER_INFO))

new_refresh_token = None
if self.refresh_token_threshold \
Expand Down Expand Up @@ -293,7 +336,7 @@ def get_subject_identifier(self, subject_type, user_id, sector_identifier=None):
raise ValueError('Unknown subject_type={}'.format(subject_type))

def _is_valid_subject_identifier(self, sub):
# type: (str) -> str
# type: (str) -> bool
"""
Determines whether the subject identifier is known.
"""
Expand All @@ -307,13 +350,33 @@ def _is_valid_subject_identifier(self, sub):
def get_user_id_for_subject_identifier(self, subject_identifier):
for user_id, subject_identifiers in self.subject_identifiers.items():
is_public_sub = 'public' in subject_identifiers and subject_identifier == subject_identifiers['public']
is_pairwise_sub = 'pairwise' in subject_identifiers and subject_identifier in subject_identifiers[
'pairwise']
is_pairwise_sub = 'pairwise' in subject_identifiers and subject_identifier in subject_identifiers['pairwise']
if is_public_sub or is_pairwise_sub:
return user_id

raise InvalidSubjectIdentifier('{} unknown'.format(subject_identifier))

def get_user_info_for_code(self, authorization_code):
# type: (str) -> dict
if authorization_code not in self.authorization_codes:
raise InvalidAuthorizationCode('{} unknown'.format(authorization_code))

return self.authorization_codes[authorization_code].get(self.KEY_USER_INFO)

def get_extra_io_token_claims_for_code(self, authorization_code):
# type: (str) -> dict
if authorization_code not in self.authorization_codes:
raise InvalidAuthorizationCode('{} unknown'.format(authorization_code))

return self.authorization_codes[authorization_code].get(self.KEY_EXTRA_ID_TOKEN_CLAIMS)

def get_user_info_for_access_token(self, access_token):
# type: (str) -> dict
if access_token not in self.access_tokens:
raise InvalidAccessToken('{} unknown'.format(access_token))

return self.access_tokens[access_token].get(self.KEY_USER_INFO)

def get_authorization_request_for_code(self, authorization_code):
# type: (str) -> AuthorizationRequest
if authorization_code not in self.authorization_codes:
Expand All @@ -323,7 +386,7 @@ def get_authorization_request_for_code(self, authorization_code):
self.authorization_codes[authorization_code][self.KEY_AUTHORIZATION_REQUEST])

def get_authorization_request_for_access_token(self, access_token_value):
# type: (str) ->
# type: (str) ->
if access_token_value not in self.access_tokens:
raise InvalidAccessToken('{} unknown'.format(access_token_value))

Expand Down
74 changes: 74 additions & 0 deletions src/pyop/crypto.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import base64
import hashlib

from Cryptodome import Random
from Cryptodome.Cipher import AES


class _AESCipher(object):
"""
This class will perform AES encryption/decryption with a keylength of 256.

@see: http://stackoverflow.com/questions/12524994/encrypt-decrypt-using-pycrypto-aes-256
"""

def __init__(self, key):
"""
Constructor

:type key: str

:param key: The key used for encryption and decryption. The longer key the better.
"""
self.bs = 32
self.key = hashlib.sha256(key.encode()).digest()

def encrypt(self, raw):
"""
Encryptes the parameter raw.

:type raw: bytes
:rtype: str

:param: bytes to be encrypted.

:return: A base 64 encoded string.
"""
raw = self._pad(raw)
iv = Random.new().read(AES.block_size)
cipher = AES.new(self.key, AES.MODE_CBC, iv)
return base64.urlsafe_b64encode(iv + cipher.encrypt(raw))

def decrypt(self, enc):
"""
Decryptes the parameter enc.

:type enc: bytes
:rtype: bytes

:param: The value to be decrypted.
:return: The decrypted value.
"""
enc = base64.urlsafe_b64decode(enc)
iv = enc[:AES.block_size]
cipher = AES.new(self.key, AES.MODE_CBC, iv)
return self._unpad(cipher.decrypt(enc[AES.block_size:]))

def _pad(self, b):
"""
Will padd the param to be of the correct length for the encryption alg.

:type b: bytes
:rtype: bytes
"""
return b + (self.bs - len(b) % self.bs) * chr(self.bs - len(b) % self.bs).encode("UTF-8")

@staticmethod
def _unpad(b):
"""
Removes the padding performed by the method _pad.

:type b: bytes
:rtype: bytes
"""
return b[:-ord(b[len(b) - 1:])]
Loading