From 1a1e58dcc08e7bb40d82c2e4576c63645e81f96a Mon Sep 17 00:00:00 2001 From: Marshall Brekka Date: Mon, 17 Jun 2024 12:09:59 -0700 Subject: [PATCH] Change websocket cleanup performance from O(n^2) to O(n). Move subscription tracking to a dict keyed by the underlying websocket object. Changes the cleanup from two nested for loops (the 2nd being the unsubscribe_from method), to a single loop with (the common case) deleting a single key from a dict. --- services/ui_backend_service/api/ws.py | 39 ++++++++++++++++++--------- 1 file changed, 26 insertions(+), 13 deletions(-) diff --git a/services/ui_backend_service/api/ws.py b/services/ui_backend_service/api/ws.py index 7813f7f3..98fffa51 100644 --- a/services/ui_backend_service/api/ws.py +++ b/services/ui_backend_service/api/ws.py @@ -47,7 +47,7 @@ class Websocket(object): Example event: {"type": "UPDATE", "uuid": "myst3rySh4ck", "resource": "/runs", "data": {"foo": "bar"}} ''' - subscriptions: List[WSSubscription] = [] + _subscriptions: Dict[web.WebSocketResponse, List[WSSubscription]] = collections.defaultdict(list) def __init__(self, app, db, event_emitter=None, queue_ttl: int = WS_QUEUE_TTL_SECONDS, cache=None): self.event_emitter = event_emitter or AsyncIOEventEmitter() @@ -83,7 +83,7 @@ async def event_handler(self, operation: str, resources: List[str], data: Dict, a dictionary of filters used in the query when fetching complete data. """ # Check if event needs to be broadcast (if anyone is subscribed to the resource) - if any(subscription.resource in resources for subscription in self.subscriptions): + if any(subscription.resource in resources for subscription in self.subscriptions()): # load the data and postprocessor for broadcasting if table # is provided (otherwise data has already been loaded in advance) if table_name: @@ -108,18 +108,28 @@ async def event_handler(self, operation: str, resources: List[str], data: Dict, 'resources': resources, 'data': _data }) - for subscription in self.subscriptions: + for subscription in self.subscriptions(): try: if subscription.disconnected_ts and time.time() - subscription.disconnected_ts > WS_QUEUE_TTL_SECONDS: - await self.unsubscribe_from(subscription.ws, subscription.uuid) + # We can assume that all websockets (not just this UUID) are disconnected, don't filter by UUID as well. + await self.unsubscribe_from(subscription.ws) else: await self._event_subscription(subscription, operation, resources, _data) except ConnectionResetError: self.logger.debug("Trying to broadcast to a stale subscription. Unsubscribing") - await self.unsubscribe_from(subscription.ws, subscription.uuid) + await self.unsubscribe_from(subscription.ws) except Exception: self.logger.exception("Broadcasting to subscription failed") + def subscriptions(self): + # Grab all of the keys upfront and use that to iterate so that callers can + # safely modify the subscriptions dict while we are iterating through it. + # This is primarily useful when calling `unsubscribe_from` during the cleanup + # loop in the event handler. + for k in list(self._subscriptions.keys()): + for sub in self._subscriptions[k]: + yield sub + async def _event_subscription(self, subscription: WSSubscription, operation: str, resources: List[str], data: Dict): for resource in resources: if subscription.resource == resource: @@ -142,7 +152,7 @@ async def subscribe_to(self, ws, uuid: str, resource: str, since: int): subscription = WSSubscription( ws=ws, fullpath=resource, resource=_resource, query=query, uuid=uuid, filter=filter_fn, disconnected_ts=None) - self.subscriptions.append(subscription) + self._subscriptions[ws].append(subscription) # Send previous events that client might have missed due to disconnection if since: @@ -154,22 +164,25 @@ async def subscribe_to(self, ws, uuid: str, resource: str, since: int): ) async def unsubscribe_from(self, ws, uuid: str = None): + if ws not in self._subscriptions: + return if uuid: - self.subscriptions = list( - filter(lambda s: uuid != s.uuid or ws != s.ws, self.subscriptions)) + self._subscriptions[ws] = list( + filter(lambda s: uuid != s.uuid or ws != s.ws, self._subscriptions[ws])) + if len(self._subscriptions[ws]) == 0: + del self._subscriptions[ws] else: - self.subscriptions = list( - filter(lambda s: ws != s.ws, self.subscriptions)) + del self._subscriptions[ws] async def handle_disconnect(self, ws): """ Sets disconnected timestamp on websocket subscription without removing it from the list. Removing is handled by event_handler that checks for expired subscriptions before emitting """ - self.subscriptions = list( + self._subscriptions[ws] = list( map( - lambda sub: sub._replace(disconnected_ts=time.time()) if sub.ws == ws else sub, - self.subscriptions) + lambda sub: sub._replace(disconnected_ts=time.time()), + self._subscriptions[ws]) ) async def websocket_handler(self, request):