Skip to content

Commit

Permalink
Basic OAuth2 with code interchange done
Browse files Browse the repository at this point in the history
  • Loading branch information
dkmstr committed Oct 19, 2023
1 parent 9419e0f commit ad97551
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 29 deletions.
84 changes: 60 additions & 24 deletions server/src/uds/auths/OAuth2/authenticator.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,25 +178,36 @@ class OAuth2Authenticator(auths.Authenticator):
required=False,
tab=types.ui.Tab.ADVANCED,
)
# Attributes info fields
userAttribute = gui.TextField(
length=64,
label=_('Username attribute'),

userNameAttr = gui.TextField(
length=2048,
lines=2,
label=_('User name attrs'),
order=100,
tooltip=_('Attribute that contains the username'),
tooltip=_('Fields from where to extract user name'),
required=True,
tab=_('Attributes'),
)
groupsAttributes = gui.TextField(
length=64,
label=_('Groups attribute'),

groupNameAttr = gui.TextField(
length=2048,
lines=2,
label=_('Group name attrs'),
order=101,
tooltip=_('Attribute that contains the groups'),
required=True,
tooltip=_('Fields from where to extract the groups'),
required=False,
tab=_('Attributes'),
)

realNameAttr = gui.TextField(
length=2048,
lines=2,
label=_('Real name attrs'),
order=102,
tooltip=_('Fields from where to extract the real name'),
required=False,
tab=_('Attributes'),
)


def initialize(self, values: typing.Optional[typing.Dict[str, typing.Any]]) -> None:
if not values:
Expand All @@ -206,10 +217,9 @@ def initialize(self, values: typing.Optional[typing.Dict[str, typing.Any]]) -> N
raise exceptions.ValidationError(
gettext('This kind of Authenticator does not support white spaces on field NAME')
)

auth_utils.validateRegexField(self.userAttribute)
auth_utils.validateRegexField(self.userAttribute)


auth_utils.validateRegexField(self.userNameAttr)
auth_utils.validateRegexField(self.userNameAttr)

if self.responseType.value == 'code':
if self.commonGroups.value.strip() == '':
Expand Down Expand Up @@ -266,7 +276,6 @@ def _requestToken(self, request: 'HttpRequest', code: str) -> TokenInfo:
raise Exception('Error requesting token: {}'.format(req.text))

return TokenInfo.fromJson(req.json())


def authCallback(
self,
Expand All @@ -291,6 +300,18 @@ def getJavascript(self, request: 'HttpRequest') -> typing.Optional[str]:
"""
return f'window.location="{self._getLoginURL(request)}";'

def getGroups(self, username: str, groupsManager: 'auths.GroupsManager'):
data = self.storage.getPickle(username)
if not data:
return
groupsManager.validate(data[1])

def getRealName(self, username: str) -> str:
data = self.storage.getPickle(username)
if not data:
return username
return data[0]

def authCallbackCode(
self,
parameters: 'types.auth.AuthCallbackParams',
Expand All @@ -313,26 +334,41 @@ def authCallbackCode(
return auths.FAILED_AUTH

token = self._requestToken(request, code)

userInfo: typing.Dict[str, typing.Any]

if self.infoEndpoint.value.strip() == '':
if not token.info:
raise Exception('No user info received')
userInfo = token.info
else:
# Get user info
req = requests.get(self.infoEndpoint.value, headers={'Authorization': 'Bearer ' + token.access_token}, timeout=consts.COMMS_TIMEOUT)
req = requests.get(
self.infoEndpoint.value,
headers={'Authorization': 'Bearer ' + token.access_token},
timeout=consts.COMMS_TIMEOUT,
)
if not req.ok:
raise Exception('Error requesting user info: {}'.format(req.text))
userInfo = req.json()

username = ''.join(auth_utils.processRegexField(self.userNameAttr.value, userInfo)).replace(' ', '_')
if len(username) == 0:
raise Exception('No username received')

realName = ''.join(auth_utils.processRegexField(self.realNameAttr.value, userInfo))

# Get groups
groups = auth_utils.processRegexField(self.groupNameAttr.value, userInfo)
# Append common groups
groups.extend(self.commonGroups.value.split(','))

# store groups for this username at storage, so we can check it at a later stage
self.storage.putPickle(username, [realName, groups])

# Validate common groups
groups = self.commonGroups.value.split(',')
gm.validate(groups)

# We don't mind about the token, we only need to authenticate user
# and if we are here, the user is authenticated, so we can return SUCCESS_AUTH
return auths.AuthenticationResult(
auths.AuthenticationSuccess.OK, username=parameters.get_params.get('username', '')
)
return auths.AuthenticationResult(auths.AuthenticationSuccess.OK, username=username)
18 changes: 13 additions & 5 deletions server/src/uds/core/util/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,15 @@

logger = logging.getLogger(__name__)


def validateRegexField(field: ui.gui.TextField, fieldValue: typing.Optional[str] = None):
"""
Validates the multi line fields refering to attributes
"""
value: str = fieldValue or field.value
if value.strip() == '':
return # Ok, empty

for line in value.splitlines():
if line.find('=') != -1:
pattern = line.split('=')[0:2][1]
Expand All @@ -53,15 +57,21 @@ def validateRegexField(field: ui.gui.TextField, fieldValue: typing.Optional[str]
except Exception as e:
raise exceptions.ValidationError(f'Invalid pattern at {field.label}: {line}') from e

def processRegexField(field: str, attributes: typing.Mapping[str, typing.Union[str, typing.List[str]]]) -> typing.List[str]:

def processRegexField(
field: str, attributes: typing.Mapping[str, typing.Union[str, typing.List[str]]]
) -> typing.List[str]:
"""Proccesses a field, that can be a multiline field, and returns a list of values
Args:
field (str): Field to process
attributes (typing.Dict[str, typing.List[str]]): Attributes to use on processing
"""
try:
res: typing.List[str] = []
field = field.strip()
if field == '':
return res

def getAttr(attrName: str) -> typing.List[str]:
try:
Expand All @@ -71,9 +81,7 @@ def getAttr(attrName: str) -> typing.List[str]:
# Check all attributes are present, and has only one value
attrValues = [ensure.is_list(attributes.get(a, [''])) for a in attrList]
if not all([len(v) <= 1 for v in attrValues]):
logger.warning(
'Attribute %s do not has exactly one value, skipping %s', attrName, line
)
logger.warning('Attribute %s do not has exactly one value, skipping %s', attrName, line)
return val

val = [''.join(v) for v in attrValues] # flatten
Expand Down

0 comments on commit ad97551

Please sign in to comment.