Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Reduce state pulled from DB due to sending typing and receipts over federation #12964

Merged
merged 6 commits into from
Jun 6, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/12964.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Reduce state pulled from DB due to sending typing and receipts over federation.
erikjohnston marked this conversation as resolved.
Show resolved Hide resolved
6 changes: 5 additions & 1 deletion synapse/federation/sender/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,8 @@ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
self.state = hs.get_state_handler()

self._storage_controllers = hs.get_storage_controllers()

self.clock = hs.get_clock()
self.is_mine_id = hs.is_mine_id

Expand Down Expand Up @@ -602,7 +604,9 @@ async def send_read_receipt(self, receipt: ReadReceipt) -> None:
room_id = receipt.room_id

# Work out which remote servers should be poked and poke them.
domains_set = await self.state.get_current_hosts_in_room(room_id)
domains_set = await self._storage_controllers.state.get_current_hosts_in_room(
room_id
)
domains = [
d
for d in domains_set
Expand Down
7 changes: 5 additions & 2 deletions synapse/handlers/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ class FollowerTypingHandler:

def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
self._storage_controllers = hs.get_storage_controllers()
self.server_name = hs.config.server.server_name
self.clock = hs.get_clock()
self.is_mine_id = hs.is_mine_id
Expand Down Expand Up @@ -131,15 +132,17 @@ async def _push_remote(self, member: RoomMember, typing: bool) -> None:
return

try:
users = await self.store.get_users_in_room(member.room_id)
self._member_last_federation_poke[member] = self.clock.time_msec()

now = self.clock.time_msec()
self.wheel_timer.insert(
now=now, obj=member, then=now + FEDERATION_PING_INTERVAL
)

for domain in {get_domain_from_id(u) for u in users}:
hosts = await self._storage_controllers.state.get_current_hosts_in_room(
member.room_id
)
for domain in hosts:
if domain != self.server_name:
logger.debug("sending typing update to %s", domain)
self.federation.build_and_send_edu(
Expand Down
4 changes: 0 additions & 4 deletions synapse/state/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,10 +172,6 @@ async def get_current_users_in_room(
entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
return await self.store.get_joined_users_from_state(room_id, entry)

async def get_current_hosts_in_room(self, room_id: str) -> FrozenSet[str]:
event_ids = await self.store.get_latest_event_ids_in_room(room_id)
return await self.get_hosts_in_room_at_events(room_id, event_ids)

async def get_hosts_in_room_at_events(
self, room_id: str, event_ids: Collection[str]
) -> FrozenSet[str]:
Expand Down
1 change: 1 addition & 0 deletions synapse/storage/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def _invalidate_state_caches(
self._attempt_to_invalidate_cache("is_host_joined", (room_id, host))
if members_changed:
self._attempt_to_invalidate_cache("get_users_in_room", (room_id,))
self._attempt_to_invalidate_cache("get_current_hosts_in_room", (room_id,))
self._attempt_to_invalidate_cache(
"get_users_in_room_with_profiles", (room_id,)
)
Expand Down
8 changes: 8 additions & 0 deletions synapse/storage/controllers/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
List,
Mapping,
Optional,
Set,
Tuple,
)

Expand Down Expand Up @@ -482,3 +483,10 @@ async def get_current_state_event(
room_id, StateFilter.from_types((key,))
)
return state_map.get(key)

async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
"""Get current hosts in room based on current state."""

await self._partial_state_room_tracker.await_full_state(room_id)

return await self.stores.main.get_current_hosts_in_room(room_id)
37 changes: 37 additions & 0 deletions synapse/storage/databases/main/roommember.py
Original file line number Diff line number Diff line change
Expand Up @@ -893,6 +893,43 @@ async def _check_host_room_membership(

return True

@cached(iterable=True, max_entries=10000)
async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
"""Get current hosts in room based on current state."""

# First we check if we already have `get_users_in_room` in the cache, as
# we can just calculate result from that
users = self.get_users_in_room.cache.get_immediate(
(room_id,), None, update_metrics=False
)
if users is not None:
return {get_domain_from_id(u) for u in users}

if isinstance(self.database_engine, Sqlite3Engine):
# If we're using SQLite then let's just always use
# `get_users_in_room` rather than funky SQL.
Comment on lines +909 to +910
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume this is because the SQL doesn't work on SQLite (or it would need a different dialect?)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup, exactly

users = await self.get_users_in_room(room_id)
return {get_domain_from_id(u) for u in users}

# For PostgreSQL we can use a regex to pull out the domains from the
# joined users in `current_state_events` via regex.

def get_current_hosts_in_room_txn(txn: LoggingTransaction) -> Set[str]:
sql = """
SELECT DISTINCT substring(state_key FROM '@[^:]*:(.*)$')
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This regex seems to match get_domain_from_id for the cases I can think of. 👍

FROM current_state_events
WHERE
type = 'm.room.member'
AND membership = 'join'
AND room_id = ?
"""
txn.execute(sql, (room_id,))
return {d for d, in txn}

return await self.db_pool.runInteraction(
"get_current_hosts_in_room", get_current_hosts_in_room_txn
)

async def get_joined_hosts(
self, room_id: str, state_entry: "_StateCacheEntry"
) -> FrozenSet[str]:
Expand Down
14 changes: 7 additions & 7 deletions tests/federation/test_federation_sender.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,16 @@

class FederationSenderReceiptsTestCases(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
mock_state_handler = Mock(spec=["get_current_hosts_in_room"])
# Ensure a new Awaitable is created for each call.
mock_state_handler.get_current_hosts_in_room.return_value = make_awaitable(
["test", "host2"]
)
return self.setup_test_homeserver(
state_handler=mock_state_handler,
hs = self.setup_test_homeserver(
federation_transport_client=Mock(spec=["send_transaction"]),
)

hs.get_storage_controllers().state.get_current_hosts_in_room = Mock(
return_value=make_awaitable(["test", "host2"])
erikjohnston marked this conversation as resolved.
Show resolved Hide resolved
)

return hs

@override_config({"send_federation": True})
def test_send_receipts(self):
mock_send_transaction = (
Expand Down
6 changes: 4 additions & 2 deletions tests/handlers/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,10 +129,12 @@ async def check_host_in_room(room_id: str, server_name: str) -> bool:

hs.get_event_auth_handler().check_host_in_room = check_host_in_room

def get_joined_hosts_for_room(room_id: str):
async def get_current_hosts_in_room(room_id: str):
return {member.domain for member in self.room_members}

self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room
hs.get_storage_controllers().state.get_current_hosts_in_room = (
get_current_hosts_in_room
)

async def get_users_in_room(room_id: str):
return {str(u) for u in self.room_members}
Expand Down