Skip to content

Commit

Permalink
Fixed notify events to include "user_service_uuid instead of user_ser…
Browse files Browse the repository at this point in the history
…vice (name change, not uuid inside)
  • Loading branch information
dkmstr committed Sep 24, 2023
1 parent 842626e commit eff1843
Show file tree
Hide file tree
Showing 7 changed files with 110 additions and 67 deletions.
4 changes: 2 additions & 2 deletions server/src/tests/REST/servers/test_events_login_logout.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def test_login(self) -> None:
# loginData = {
# 'token': 'server token', # Must be present on all events
# 'type': 'login', # MUST BE PRESENT
# 'user_service': 'uuid', # MUST BE PRESENT
# 'user_service_uuid': 'uuid', # MUST BE PRESENT
# 'username': 'username', # Optional
# }
# Returns:
Expand All @@ -78,7 +78,7 @@ def test_login(self) -> None:
data={
'token': self.server.token,
'type': 'login',
'user_service': self.user_service_managed.uuid,
'user_service_uuid': self.user_service_managed.uuid,
'username': 'local_user_name',
},
)
Expand Down
97 changes: 52 additions & 45 deletions server/src/uds/auths/SAML/saml.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,52 +558,59 @@ def validateField(self, field: gui.TextField):
raise exceptions.ValidationError(f'Invalid pattern at {field.label}: {line}') from e

def processField(self, field: str, attributes: typing.Dict[str, typing.List[str]]) -> typing.List[str]:
res = []

def getAttr(attrName: str) -> typing.List[str]:
val: typing.List[str] = []
if '+' in attrName:
attrList = attrName.split('+')
# Check all attributes are present, and has only one value
attrValues = [ensure.is_list(attributes.get(a, [''])) for a in attrList]
if not all([len(v) <= 1 for v in attrValues]):
logger.warning('Attribute %s do not has exactly one value, skipping %s', attrName, line)
try:
res: typing.List[str] = []
def getAttr(attrName: str) -> typing.List[str]:
try:
val: typing.List[str] = []
if '+' in attrName:
attrList = attrName.split('+')
# Check all attributes are present, and has only one value
attrValues = [ensure.is_list(attributes.get(a, [''])) for a in attrList]
if not all([len(v) <= 1 for v in attrValues]):
logger.warning('Attribute %s do not has exactly one value, skipping %s', attrName, line)
return val

val = [''.join(v) for v in attrValues] # flatten
elif '**' in attrName:
# Prepend the value after : to attribute value before :
attr, prependable = attrName.split('**')
val = [prependable + a for a in ensure.is_list(attributes.get(attr, []))]
else:
val = ensure.is_list(attributes.get(attrName, []))
return val

val = [''.join(v) for v in attrValues] # flatten
elif ':' in attrName:
# Prepend the value after : to attribute value before :
attr, prependable = attrName.split(':')
val = [prependable + a for a in ensure.is_list(attributes.get(attr, []))]
else:
val = ensure.is_list(attributes.get(attrName, []))
return val

for line in field.splitlines():
equalPos = line.find('=')
if equalPos != -1:
attr, pattern = (line[:equalPos], line[equalPos + 1 :])
# if pattern do not have groups, define one with full re
if pattern.find('(') == -1:
pattern = '(' + pattern + ')'

val = getAttr(attr)

for v in val:
try:
logger.debug('Pattern: %s on value %s', pattern, v)
srch = re.search(pattern, v)
if srch is None:
continue
res.append(''.join(srch.groups()))
except Exception as e:
logger.warning('Invalid regular expression')
logger.debug(e)
break
else:
res += getAttr(line)
logger.debug('Result: %s', res)
return res
except Exception as e:
logger.warning('Error processing attribute %s (%s): %s', attrName, attributes, e)
return []

for line in field.splitlines():
equalPos = line.find('=')
if equalPos != -1:
attr, pattern = (line[:equalPos], line[equalPos + 1 :])
# if pattern do not have groups, define one with full re
if pattern.find('(') == -1:
pattern = '(' + pattern + ')'

val = getAttr(attr)

for v in val:
try:
logger.debug('Pattern: %s on value %s', pattern, v)
srch = re.search(pattern, v)
if srch is None:
continue
res.append(''.join(srch.groups()))
except Exception as e:
logger.warning('Invalid regular expression')
logger.debug(e)
break
else:
res += getAttr(line)
logger.debug('Result: %s', res)
return res
except Exception as e:
logger.warning('Error processing field %s (%s): %s', field, attributes, e)
return []

def getInfo(
self, parameters: typing.Mapping[str, str]
Expand Down
55 changes: 42 additions & 13 deletions server/src/uds/core/managers/servers.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def _retrieveStats(server: 'models.Server') -> None:
requester.ServerApiRequester(best[0]).notifyRelease(userService)

return best

def assign(
self,
userService: 'models.UserService',
Expand Down Expand Up @@ -203,7 +203,7 @@ def assign(
if not userService.user:
raise exceptions.UDSException(_('No user assigned to service'))

# Look for existint user asignation through properties
# Look for existing user asignation through properties
prop_name = self.propertyName(userService.user)
now = model_utils.getSqlDatetime()

Expand Down Expand Up @@ -337,6 +337,19 @@ def release(

return types.servers.ServerCounter(serverCounter.server_uuid, serverCounter.counter - 1)

def notifyPreconnect(
self,
serverGroup: 'models.ServerGroup',
userService: 'models.UserService',
info: types.connections.ConnectionData,
) -> None:
"""
Notifies preconnect to server
"""
server = self.getServerAssignation(userService, serverGroup)
if server:
requester.ServerApiRequester(server).notifyPreconnect(userService, info)

def getAssignInformation(self, serverGroup: 'models.ServerGroup') -> typing.Dict[str, int]:
"""
Get usage information for a server group
Expand All @@ -355,6 +368,33 @@ def getAssignInformation(self, serverGroup: 'models.ServerGroup') -> typing.Dict
res[kk] = res.get(kk, 0) + v[1]
return res

def getServerAssignation(
self,
userService: 'models.UserService',
serverGroup: 'models.ServerGroup',
) -> typing.Optional['models.Server']:
"""
Returns the server assigned to an user service
Args:
userService: User service to get server from
serverGroup: Server group to get server from
Returns:
Server assigned to user service, or None if no server is assigned
"""
if not userService.user:
raise exceptions.UDSException(_('No user assigned to service'))

prop_name = self.propertyName(userService.user)
with serverGroup.properties as props:
info: typing.Optional[
types.servers.ServerCounter
] = types.servers.ServerCounter.fromIterable(props.get(prop_name))
if info is None:
return None
return models.Server.objects.get(uuid=info.server_uuid)

def doMaintenance(self, serverGroup: 'models.ServerGroup') -> None:
"""Realizes maintenance on server group
Expand All @@ -373,17 +413,6 @@ def doMaintenance(self, serverGroup: 'models.ServerGroup') -> None:
# User does not exists, remove it from counters
del serverGroup.properties[k]

def notifyPreconnect(
self,
server: 'models.Server',
userService: 'models.UserService',
info: types.connections.ConnectionData,
) -> None:
"""
Notifies preconnect to server
"""
requester.ServerApiRequester(server).notifyPreconnect(userService, info)

def processEvent(self, server: 'models.Server', data: typing.Dict[str, typing.Any]) -> typing.Any:
"""
Processes a notification FROM server
Expand Down
10 changes: 5 additions & 5 deletions server/src/uds/core/managers/servers_api/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def process_login(server: 'models.Server', data: typing.Dict[str, typing.Any]) -
"""Processes the REST login event from a server
data: {
'user_service': 'uuid of user service',
'user_service_uuid': 'uuid of user service',
'username': 'username',
'ticket': 'ticket if any' # optional
}
Expand All @@ -80,9 +80,9 @@ def process_login(server: 'models.Server', data: typing.Dict[str, typing.Any]) -
if 'ticket' in data:
ticket = models.TicketStore.get(data['ticket'], invalidate=True)
# If ticket is included, user_service can be inside ticket or in data
data['user_service'] = data.get('user_service', ticket['user_service'])
data['user_service_uuid'] = data.get('user_service_uuid', ticket['user_service_uuid'])

userService = models.UserService.objects.get(uuid=data['user_service'])
userService = models.UserService.objects.get(uuid=data['user_service_uuid'])
server.setActorVersion(userService)

if not userService.in_use: # If already logged in, do not add a second login (windows does this i.e.)
Expand Down Expand Up @@ -118,15 +118,15 @@ def process_logout(server: 'models.Server', data: typing.Dict[str, typing.Any])
"""Processes the REST logout event from a server
data: {
'user_service': 'uuid of user service',
'user_service_uuid': 'uuid of user service',
'session_id': 'session id',
}
Returns 'OK' if all went ok ({'result': 'OK', 'stamp': 'stamp'}), or an error if not ({'result': 'error', 'error': 'error description'}})
"""
userService = models.UserService.objects.get(uuid=data['user_service'])

session_id = data['session_id']
session_id = data['user_service_uuid']
userService.closeSession(session_id)

if userService.in_use: # If already logged out, do not add a second logout (windows does this i.e.)
Expand Down
3 changes: 2 additions & 1 deletion server/src/uds/core/managers/servers_api/requester.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@

AUTH_TOKEN = 'X-TOKEN-AUTH'


# Restrainer decorator
# If server is restrained, it will return False
# If server is not restrained, it will execute the function and return it's result
Expand Down Expand Up @@ -230,7 +231,7 @@ def notifyRelease(self, userService: 'models.UserService') -> bool:
Notifies removal of user service to server
"""
logger.debug('Notifying release of service %s to server %s', userService.uuid, self.server.host)
self.post('release', {'userservice': userService.uuid})
self.post('release', types.connections.ReleaseRequest(userservice_uuid=userService.uuid).asDict())

return True

Expand Down
6 changes: 6 additions & 0 deletions server/src/uds/core/types/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,12 @@ class AssignRequest(typing.NamedTuple):
def asDict(self) -> typing.Dict[str, 'str|int']:
return self._asdict()

class ReleaseRequest(typing.NamedTuple):
"""Information sent on a release request"""
userservice_uuid: str # UUID of userservice

def asDict(self) -> typing.Dict[str, str]:
return self._asdict()

class ConnectionData(typing.NamedTuple):
"""
Expand Down
2 changes: 1 addition & 1 deletion server/src/uds/models/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def notifyPreconnect(self, userService: 'UserService', info: 'types.connections.
Override this method if you need to do something before connecting to a service
(i.e. invoke notifyPreconnect using a Server, or whatever you need to do)
"""
pass
logger.warning('No actor notification available for user service %s', userService.friendly_name)

@property
def oldMaxAccountingMethod(self) -> bool:
Expand Down

0 comments on commit eff1843

Please sign in to comment.