Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change websocket cleanup performance from O(n^2) to O(n). #429

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 26 additions & 13 deletions services/ui_backend_service/api/ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class Websocket(object):
Example event:
{"type": "UPDATE", "uuid": "myst3rySh4ck", "resource": "/runs", "data": {"foo": "bar"}}
'''
subscriptions: List[WSSubscription] = []
_subscriptions: Dict[web.WebSocketResponse, List[WSSubscription]] = collections.defaultdict(list)

def __init__(self, app, db, event_emitter=None, queue_ttl: int = WS_QUEUE_TTL_SECONDS, cache=None):
self.event_emitter = event_emitter or AsyncIOEventEmitter()
Expand Down Expand Up @@ -83,7 +83,7 @@ async def event_handler(self, operation: str, resources: List[str], data: Dict,
a dictionary of filters used in the query when fetching complete data.
"""
# Check if event needs to be broadcast (if anyone is subscribed to the resource)
if any(subscription.resource in resources for subscription in self.subscriptions):
if any(subscription.resource in resources for subscription in self.subscriptions()):
# load the data and postprocessor for broadcasting if table
# is provided (otherwise data has already been loaded in advance)
if table_name:
Expand All @@ -108,18 +108,28 @@ async def event_handler(self, operation: str, resources: List[str], data: Dict,
'resources': resources,
'data': _data
})
for subscription in self.subscriptions:
for subscription in self.subscriptions():
try:
if subscription.disconnected_ts and time.time() - subscription.disconnected_ts > WS_QUEUE_TTL_SECONDS:
await self.unsubscribe_from(subscription.ws, subscription.uuid)
# We can assume that all websockets (not just this UUID) are disconnected, don't filter by UUID as well.
await self.unsubscribe_from(subscription.ws)
else:
await self._event_subscription(subscription, operation, resources, _data)
except ConnectionResetError:
self.logger.debug("Trying to broadcast to a stale subscription. Unsubscribing")
await self.unsubscribe_from(subscription.ws, subscription.uuid)
await self.unsubscribe_from(subscription.ws)
except Exception:
self.logger.exception("Broadcasting to subscription failed")

def subscriptions(self):
# Grab all of the keys upfront and use that to iterate so that callers can
# safely modify the subscriptions dict while we are iterating through it.
# This is primarily useful when calling `unsubscribe_from` during the cleanup
# loop in the event handler.
for k in list(self._subscriptions.keys()):
for sub in self._subscriptions[k]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note to self: no KeyError possible due to self._subscriptions being a defaultdict(list)

yield sub

async def _event_subscription(self, subscription: WSSubscription, operation: str, resources: List[str], data: Dict):
for resource in resources:
if subscription.resource == resource:
Expand All @@ -142,7 +152,7 @@ async def subscribe_to(self, ws, uuid: str, resource: str, since: int):
subscription = WSSubscription(
ws=ws, fullpath=resource, resource=_resource, query=query, uuid=uuid,
filter=filter_fn, disconnected_ts=None)
self.subscriptions.append(subscription)
self._subscriptions[ws].append(subscription)

# Send previous events that client might have missed due to disconnection
if since:
Expand All @@ -154,22 +164,25 @@ async def subscribe_to(self, ws, uuid: str, resource: str, since: int):
)

async def unsubscribe_from(self, ws, uuid: str = None):
if ws not in self._subscriptions:
return
if uuid:
self.subscriptions = list(
filter(lambda s: uuid != s.uuid or ws != s.ws, self.subscriptions))
self._subscriptions[ws] = list(
filter(lambda s: uuid != s.uuid or ws != s.ws, self._subscriptions[ws]))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is the ws != s.ws check necessary anymore as the subscription is keyed per websocket?

if len(self._subscriptions[ws]) == 0:
del self._subscriptions[ws]
else:
self.subscriptions = list(
filter(lambda s: ws != s.ws, self.subscriptions))
del self._subscriptions[ws]

async def handle_disconnect(self, ws):
"""
Sets disconnected timestamp on websocket subscription without removing it from the list.
Removing is handled by event_handler that checks for expired subscriptions before emitting
"""
self.subscriptions = list(
self._subscriptions[ws] = list(
map(
lambda sub: sub._replace(disconnected_ts=time.time()) if sub.ws == ws else sub,
self.subscriptions)
lambda sub: sub._replace(disconnected_ts=time.time()),
self._subscriptions[ws])
)

async def websocket_handler(self, request):
Expand Down
Loading