diff --git a/python/ray/serve/_private/long_poll.py b/python/ray/serve/_private/long_poll.py index 6d0ea42a16dd..f3538913b76b 100644 --- a/python/ray/serve/_private/long_poll.py +++ b/python/ray/serve/_private/long_poll.py @@ -88,7 +88,11 @@ def __init__( self.key_listeners = key_listeners self.event_loop = call_in_event_loop self.snapshot_ids: Dict[KeyType, int] = { - key: -1 for key in self.key_listeners.keys() + # The initial snapshot id for each key is < 0, + # but real snapshot keys in the long poll host are always >= 0, + # so this will always trigger an initial update. + key: -1 + for key in self.key_listeners.keys() } self.is_running = True @@ -191,11 +195,9 @@ def __init__( ] = LISTEN_FOR_CHANGE_REQUEST_TIMEOUT_S, ): # Map object_key -> int - self.snapshot_ids: DefaultDict[KeyType, int] = defaultdict( - lambda: random.randint(0, 1_000_000) - ) + self.snapshot_ids: Dict[KeyType, int] = {} # Map object_key -> object - self.object_snapshots: Dict[KeyType, Any] = dict() + self.object_snapshots: Dict[KeyType, Any] = {} # Map object_key -> set(asyncio.Event waiting for updates) self.notifier_events: DefaultDict[KeyType, Set[asyncio.Event]] = defaultdict( set @@ -247,16 +249,24 @@ async def listen_for_change( immediately if the snapshot_ids are outdated, otherwise it will block until there's an update. """ - watched_keys = keys_to_snapshot_ids.keys() - existent_keys = set(watched_keys).intersection(set(self.snapshot_ids.keys())) - # If there are any keys with outdated snapshot ids, # return their updated values immediately. - updated_objects = { - key: UpdatedObject(self.object_snapshots[key], self.snapshot_ids[key]) - for key in existent_keys - if self.snapshot_ids[key] != keys_to_snapshot_ids[key] - } + updated_objects = {} + for key, client_snapshot_id in keys_to_snapshot_ids.items(): + try: + existing_id = self.snapshot_ids[key] + except KeyError: + # The caller may ask for keys that we don't know about (yet), + # just ignore them. + # This can happen when, for example, + # a deployment handle is manually created for an app + # that hasn't been deployed yet (by bypassing the safety checks). + continue + + if existing_id != client_snapshot_id: + updated_objects[key] = UpdatedObject( + self.object_snapshots[key], existing_id + ) if len(updated_objects) > 0: self._count_send(updated_objects) return updated_objects @@ -264,7 +274,7 @@ async def listen_for_change( # Otherwise, register asyncio events to be waited. async_task_to_events = {} async_task_to_watched_keys = {} - for key in watched_keys: + for key in keys_to_snapshot_ids.keys(): # Create a new asyncio event for this key. event = asyncio.Event() @@ -398,10 +408,16 @@ def notify_changed( object_key: KeyType, updated_object: Any, ): - self.snapshot_ids[object_key] += 1 + try: + self.snapshot_ids[object_key] += 1 + except KeyError: + # Initial snapshot id must be >= 0, so that the long poll client + # can send a negative initial snapshot id to get a fast update. + # They should also be randomized; + # see https://github.com/ray-project/ray/pull/45881#discussion_r1645243485 + self.snapshot_ids[object_key] = random.randint(0, 1_000_000) self.object_snapshots[object_key] = updated_object logger.debug(f"LongPollHost: Notify change for key {object_key}.") - if object_key in self.notifier_events: - for event in self.notifier_events.pop(object_key): - event.set() + for event in self.notifier_events.pop(object_key, set()): + event.set()