[Serve] Avoid looping over all snapshot ids for each long poll request #45881

47 changes: 29 additions & 18 deletions python/ray/serve/_private/long_poll.py
@@ -88,7 +88,11 @@ def __init__(
self.key_listeners = key_listeners
self.event_loop = call_in_event_loop
self.snapshot_ids: Dict[KeyType, int] = {
key: -1 for key in self.key_listeners.keys()
# The initial snapshot id for each key is < 0,
# but real snapshot keys in the long poll host are always >= 0,
# so this will always trigger an initial update.
key: -1
for key in self.key_listeners.keys()
}
self.is_running = True

@@ -191,11 +195,9 @@ def __init__(
] = LISTEN_FOR_CHANGE_REQUEST_TIMEOUT_S,
):
# Map object_key -> int
self.snapshot_ids: DefaultDict[KeyType, int] = defaultdict(
lambda: random.randint(0, 1_000_000)
)
self.snapshot_ids: Dict[KeyType, int] = {}
# Map object_key -> object
self.object_snapshots: Dict[KeyType, Any] = dict()
self.object_snapshots: Dict[KeyType, Any] = {}
# Map object_key -> set(asyncio.Event waiting for updates)
self.notifier_events: DefaultDict[KeyType, Set[asyncio.Event]] = defaultdict(
set
@@ -247,24 +249,29 @@ async def listen_for_change(
immediately if the snapshot_ids are outdated, otherwise it will block
until there's an update.
"""
watched_keys = keys_to_snapshot_ids.keys()
existent_keys = set(watched_keys).intersection(set(self.snapshot_ids.keys()))

# If there are any keys with outdated snapshot ids,
# return their updated values immediately.
updated_objects = {
key: UpdatedObject(self.object_snapshots[key], self.snapshot_ids[key])
for key in existent_keys
if self.snapshot_ids[key] != keys_to_snapshot_ids[key]
}
updated_objects = {}
for key, snapshot_id in keys_to_snapshot_ids.items():
try:
existing_id = self.snapshot_ids[key]
except KeyError:
# The caller may ask for keys that we don't know about (yet),
# just ignore them.
continue

if existing_id != snapshot_id:
updated_objects[key] = UpdatedObject(
self.object_snapshots[key], existing_id
)
Contributor

We can avoid raising and catching an error by using a conditional block.

Suggested change
try:
existing_id = self.snapshot_ids[key]
except KeyError:
# The caller may ask for keys that we don't know about (yet),
# just ignore them.
continue
if existing_id != snapshot_id:
updated_objects[key] = UpdatedObject(
self.object_snapshots[key], existing_id
)
# The caller may ask for keys that we don't know about (yet),
# just ignore them.
if key in self.snapshot_ids:
latest_id = self.snapshot_ids[key]
if latest_id != snapshot_id:
updated_objects[key] = UpdatedObject(
self.object_snapshots[key], latest_id
)
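
A minimal, standalone sketch of the lookup pattern in this suggestion may help when comparing it with the old set-intersection version. It loops only over the keys the caller sent, so the work scales with the request rather than with every key the host tracks. The `UpdatedObject` tuple and the `find_outdated` helper below are simplified stand-ins written for illustration, not the definitions in `long_poll.py`.

```python
from typing import Any, Dict, NamedTuple


class UpdatedObject(NamedTuple):
    # Simplified stand-in for the UpdatedObject used in the diff (assumed shape).
    object_snapshot: Any
    snapshot_id: int


def find_outdated(
    host_ids: Dict[str, int],
    host_objects: Dict[str, Any],
    client_ids: Dict[str, int],
) -> Dict[str, UpdatedObject]:
    """Return updates for requested keys whose snapshot id differs on the host."""
    updated = {}
    # Iterate over the requested keys only, not over every key the host knows.
    for key, client_id in client_ids.items():
        if key not in host_ids:
            # The caller may ask about keys the host has not seen yet: skip them.
            continue
        latest_id = host_ids[key]
        if latest_id != client_id:
            updated[key] = UpdatedObject(host_objects[key], latest_id)
    return updated


host_ids = {"a": 3, "b": 7, "c": 1}
host_objects = {"a": "A3", "b": "B7", "c": "C1"}
# "a" is current, "b" is stale (client sent -1), "missing" is unknown to the host.
result = find_outdated(host_ids, host_objects, {"a": 3, "b": -1, "missing": 0})
```

Here only `"b"` comes back: `"a"` matches the host's id and `"missing"` is skipped without raising.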

if len(updated_objects) > 0:
self._count_send(updated_objects)
return updated_objects

# Otherwise, register asyncio events to be waited.
async_task_to_events = {}
async_task_to_watched_keys = {}
for key in watched_keys:
for key in keys_to_snapshot_ids.keys():
# Create a new asyncio event for this key.
event = asyncio.Event()

@@ -398,10 +405,14 @@ def notify_changed(
object_key: KeyType,
updated_object: Any,
):
self.snapshot_ids[object_key] += 1
try:
self.snapshot_ids[object_key] += 1
Comment on lines -401 to +412
Contributor Author

Not a defaultdict anymore :(

except KeyError:
# Initial snapshot id must be >= 0, so that the long poll client
# can send a negative initial snapshot id to get a fast update.
self.snapshot_ids[object_key] = 0
Contributor Author

This was previously a random number https://github.com/ray-project/ray/pull/45881/files#diff-f138b21f7ddcd7d61c0b2704c8b828b9bbe7eb5021531e2c7fabeb20ec322e1aL195 but I couldn't figure out why - I can of course use a random number here but this seems simpler if that's not necessary.

Contributor

@shrekris-anyscale do you know why this was a random number / is that necessary?

Contributor

@shrekris-anyscale, Jun 19, 2024

We set it to a random number to handle an edge case when the controller crashes and recovers. Suppose the controller always starts with ID 0, and it's currently at ID 1. Suppose all the clients are also at ID 1. Now suppose:

  1. The ServeController (which runs the LongPollHost) crashes and restarts. The LongPollHost's snapshot IDs are reset to 0.
  2. All clients, except one slow client, reconnect to the LongPollHost on the ServeController. Their snapshot IDs are still 1, so the controller propagates an update, and all the connected clients' snapshot IDs are set back to 0.
  3. One of the connected clients sends an update to the LongPollHost using notify_changed. The controller bumps the snapshot ID to 1 and updates the connected clients.
  4. The slow client finally connects to the controller. Its ID is also 1, so it doesn't receive the update.

To correctly avoid this edge case, we should cache the snapshot IDs and restore them when the controller recovers. That's pretty complex though, so instead we initialize the snapshot IDs to a random number between 0 and 1,000,000. That makes this edge case very unlikely.

Could we switch back to using a default dict with the random number generator as the factory function?
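
The four steps above can be replayed with a toy model. This is only a sketch under simplifying assumptions: one key, one `notify_changed` call after the restart, and the slow client's stale ID held at 1 as in the example. The helper `slow_client_misses_update` is hypothetical and does not exist in Ray.

```python
# Every value that random.randint(0, 1_000_000) can return (both ends inclusive).
ID_RANGE = range(0, 1_000_001)


def slow_client_misses_update(initial_id: int, client_id: int, bumps: int = 1) -> bool:
    """Return True if the slow client's stale id equals the host's id after
    the restart and `bumps` changes, meaning the host sends it no update."""
    host_id = initial_id + bumps  # steps 1 and 3: restart, then notify_changed
    return host_id == client_id   # step 4: equal ids look "up to date"


# Restarting at a fixed 0: the slow client at id 1 is skipped.
fixed_start_missed = slow_client_misses_update(initial_id=0, client_id=1)

# Restarting at a random id: count the starting values that reproduce the miss.
colliding_starts = [i for i in ID_RANGE if slow_client_misses_update(i, client_id=1)]
miss_probability = len(colliding_starts) / len(ID_RANGE)
```

In this model exactly one starting value out of 1,000,001 collides, which matches the "very unlikely" framing: randomization shrinks the window but does not close it.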

Contributor Author

Oh! That totally makes sense, and I will restore that behavior.

> Could we switch back to using a default dict with the random number generator as the factory function?

Let me play around with this, I might be able to make it work

Contributor Author

Oh, I just remembered why I changed it to a plain dict: the problem is that snapshot_ids is a defaultdict but object_snapshots isn't, so I was trying to get items from one that didn't exist in the other. Now I'm wondering if those two mappings can be combined, since they have the same keys 🤔
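
One possible shape for the combined mapping, sketched under assumptions rather than taken from the PR: a single dict whose values pair the snapshot ID with the object, so a key can never exist in one mapping and not the other. `Snapshot` and `SnapshotStore` are hypothetical names; the random initial ID follows the range mentioned earlier in this thread.

```python
import random
from dataclasses import dataclass
from typing import Any, Dict, Hashable


@dataclass
class Snapshot:
    # Hypothetical record: the id and the object it describes travel together.
    snapshot_id: int
    obj: Any


class SnapshotStore:
    """Toy replacement for the separate snapshot_ids / object_snapshots dicts."""

    def __init__(self) -> None:
        self._snapshots: Dict[Hashable, Snapshot] = {}

    def notify_changed(self, key: Hashable, obj: Any) -> int:
        entry = self._snapshots.get(key)
        if entry is None:
            # First update for this key: start at a random non-negative id so
            # that a client's negative initial id still triggers an update.
            entry = Snapshot(random.randint(0, 1_000_000), obj)
            self._snapshots[key] = entry
        else:
            entry.snapshot_id += 1
            entry.obj = obj
        return entry.snapshot_id


store = SnapshotStore()
first_id = store.notify_changed("route_table", {"v": 1})
second_id = store.notify_changed("route_table", {"v": 2})
```

With one mapping, the lookup in `listen_for_change` would need a single membership check per key, at the cost of touching every place that reads either dict today.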

self.object_snapshots[object_key] = updated_object
logger.debug(f"LongPollHost: Notify change for key {object_key}.")

if object_key in self.notifier_events:
for event in self.notifier_events.pop(object_key):
event.set()
for event in self.notifier_events.pop(object_key, set()):
event.set()
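
The switch to `pop(object_key, set())` relies on `dict.pop` returning its default for a missing key without calling the `defaultdict` factory. A small sketch of that behavior, using a module-level `notifier_events` and a hypothetical `wake_waiters` helper rather than the real `LongPollHost` method:

```python
import asyncio
from collections import defaultdict
from typing import DefaultDict, Set

notifier_events: DefaultDict[str, Set[asyncio.Event]] = defaultdict(set)


def wake_waiters(key: str) -> int:
    """Set every event waiting on `key` and return how many were woken."""
    # pop() with a default returns an empty set when nobody is waiting and,
    # unlike indexing, does not insert the key into the defaultdict.
    events = notifier_events.pop(key, set())
    for event in events:
        event.set()
    return len(events)


async def demo():
    waiter = asyncio.Event()
    notifier_events["route_table"].add(waiter)
    woken = wake_waiters("route_table")        # one registered waiter
    woken_again = wake_waiters("route_table")  # key already popped
    return waiter.is_set(), woken, woken_again, "route_table" in notifier_events


demo_result = asyncio.run(demo())
```

The second call finds nothing to wake and leaves the mapping empty, which is why the explicit `if object_key in self.notifier_events` guard is no longer needed.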