Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit d38c73e

Browse files
authored
Skip waiting for full state if a StateFilter does not require it (#12498)
If `StateFilter` specifies a state set which we will have regardless of state-syncing, then we may as well return it immediately.
1 parent 0fce474 commit d38c73e

File tree

2 files changed

+60
-4
lines changed

2 files changed

+60
-4
lines changed

changelog.d/12498.misc

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Preparation for faster-room-join work: return subsets of room state which we already have, immediately.

synapse/storage/state.py

+59-4
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# Copyright 2014-2016 OpenMarket Ltd
2+
# Copyright 2022 The Matrix.org Foundation C.I.C.
23
#
34
# Licensed under the Apache License, Version 2.0 (the "License");
45
# you may not use this file except in compliance with the License.
@@ -15,6 +16,7 @@
1516
from typing import (
1617
TYPE_CHECKING,
1718
Awaitable,
19+
Callable,
1820
Collection,
1921
Dict,
2022
Iterable,
@@ -532,6 +534,44 @@ def approx_difference(self, other: "StateFilter") -> "StateFilter":
532534
new_all, new_excludes, new_wildcards, new_concrete_keys
533535
)
534536

537+
def must_await_full_state(self, is_mine_id: Callable[[str], bool]) -> bool:
538+
"""Check if we need to wait for full state to complete to calculate this state
539+
540+
If we have a state filter which is completely satisfied even with partial
541+
state, then we don't need to await_full_state before we can return it.
542+
543+
Args:
544+
is_mine_id: a callable which confirms if a given state_key matches a mxid
545+
of a local user
546+
"""
547+
548+
# TODO(faster_joins): it's not entirely clear that this is safe. In particular,
549+
# there may be circumstances in which we return a piece of state that, once we
550+
# resync the state, we discover is invalid. For example: if it turns out that
551+
# the sender of a piece of state wasn't actually in the room, then clearly that
552+
# state shouldn't have been returned.
553+
# We should at least add some tests around this to see what happens.
554+
555+
# if we haven't requested membership events, then it depends on the value of
556+
# 'include_others'
557+
if EventTypes.Member not in self.types:
558+
return self.include_others
559+
560+
# if we're looking for *all* membership events, then we have to wait
561+
member_state_keys = self.types[EventTypes.Member]
562+
if member_state_keys is None:
563+
return True
564+
565+
# otherwise, consider whose membership we are looking for. If it's entirely
566+
# local users, then we don't need to wait.
567+
for state_key in member_state_keys:
568+
if not is_mine_id(state_key):
569+
# remote user
570+
return True
571+
572+
# local users only
573+
return False
574+
535575

536576
_ALL_STATE_FILTER = StateFilter(types=frozendict(), include_others=True)
537577
_ALL_NON_MEMBER_STATE_FILTER = StateFilter(
@@ -544,6 +584,7 @@ class StateGroupStorage:
544584
"""High level interface to fetching state for event."""
545585

546586
def __init__(self, hs: "HomeServer", stores: "Databases"):
587+
self._is_mine_id = hs.is_mine_id
547588
self.stores = stores
548589
self._partial_state_events_tracker = PartialStateEventsTracker(stores.main)
549590

@@ -675,7 +716,13 @@ async def get_state_for_events(
675716
RuntimeError if we don't have a state group for one or more of the events
676717
(ie they are outliers or unknown)
677718
"""
678-
event_to_groups = await self.get_state_group_for_events(event_ids)
719+
await_full_state = True
720+
if state_filter and not state_filter.must_await_full_state(self._is_mine_id):
721+
await_full_state = False
722+
723+
event_to_groups = await self.get_state_group_for_events(
724+
event_ids, await_full_state=await_full_state
725+
)
679726

680727
groups = set(event_to_groups.values())
681728
group_to_state = await self.stores.state._get_state_for_groups(
@@ -699,7 +746,9 @@ async def get_state_for_events(
699746
return {event: event_to_state[event] for event in event_ids}
700747

701748
async def get_state_ids_for_events(
702-
self, event_ids: Collection[str], state_filter: Optional[StateFilter] = None
749+
self,
750+
event_ids: Collection[str],
751+
state_filter: Optional[StateFilter] = None,
703752
) -> Dict[str, StateMap[str]]:
704753
"""
705754
Get the state dicts corresponding to a list of events, containing the event_ids
@@ -716,7 +765,13 @@ async def get_state_ids_for_events(
716765
RuntimeError if we don't have a state group for one or more of the events
717766
(ie they are outliers or unknown)
718767
"""
719-
event_to_groups = await self.get_state_group_for_events(event_ids)
768+
await_full_state = True
769+
if state_filter and not state_filter.must_await_full_state(self._is_mine_id):
770+
await_full_state = False
771+
772+
event_to_groups = await self.get_state_group_for_events(
773+
event_ids, await_full_state=await_full_state
774+
)
720775

721776
groups = set(event_to_groups.values())
722777
group_to_state = await self.stores.state._get_state_for_groups(
@@ -802,7 +857,7 @@ async def get_state_group_for_events(
802857
Args:
803858
event_ids: events to get state groups for
804859
await_full_state: if true, will block if we do not yet have complete
805-
state at this event.
860+
state at these events.
806861
"""
807862
if await_full_state:
808863
await self._partial_state_events_tracker.await_full_state(event_ids)

0 commit comments

Comments
 (0)