Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Don't pull out state in compute_event_context for unconflicted state #13267

Merged
merged 8 commits into from
Jul 14, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion synapse/handlers/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -1444,7 +1444,12 @@ async def cache_joined_hosts_for_event(
if state_entry.state_group in self._external_cache_joined_hosts_updates:
return

joined_hosts = await self.store.get_joined_hosts(event.room_id, state_entry)
state = await state_entry.get_state(
self._storage_controllers.state, StateFilter.all()
)
joined_hosts = await self.store.get_joined_hosts(
event.room_id, state, state_entry
)

# Note that the expiry times must be larger than the expiry time in
# _external_cache_joined_hosts_updates.
Expand Down
97 changes: 64 additions & 33 deletions synapse/state/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import heapq
import logging
from collections import defaultdict
from enum import auto
from optparse import Option
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -47,13 +49,15 @@
from synapse.state import v1, v2
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.roommember import ProfileInfo
from synapse.storage.state import StateFilter
from synapse.types import StateMap
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.metrics import Measure, measure_func

if TYPE_CHECKING:
from synapse.server import HomeServer
from synapse.storage.controllers import StateStorageController
from synapse.storage.databases.main import DataStore

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -87,13 +91,16 @@ class _StateCacheEntry:

def __init__(
self,
state: StateMap[str],
state: Optional[StateMap[str]],
state_group: Optional[int],
prev_group: Optional[int] = None,
delta_ids: Optional[StateMap[str]] = None,
):
if state is None and state_group is None:
raise Exception("Either state or state group must be not None")

# A map from (type, state_key) to event_id.
self.state = frozendict(state)
self.state = frozendict(state) if state is not None else None

# the ID of a state group if one and only one is involved.
# None otherwise?
Expand All @@ -102,8 +109,26 @@ def __init__(
self.prev_group = prev_group
self.delta_ids = frozendict(delta_ids) if delta_ids is not None else None

async def get_state(
    self,
    state_storage: "StateStorageController",
    state_filter: Optional["StateFilter"] = None,
) -> StateMap[str]:
    """Return this entry's state map, loading it from storage if needed.

    When the entry already holds its state in memory that map is returned
    as-is; note that `state_filter` is *not* applied in that case. Otherwise
    the entry's state group is looked up via `state_storage`.

    Args:
        state_storage: Controller used to fetch the state IDs for
            `self.state_group` when no in-memory state is held.
        state_filter: Passed straight through to
            `get_state_ids_for_group`; only consulted on the storage path.

    Returns:
        The state map for this entry (per the class, a map from
        (type, state_key) to event_id).
    """
    in_memory_state = self.state
    if in_memory_state is None:
        # No state held in memory, so the entry must be backed by a state
        # group that we can look up instead.
        assert self.state_group is not None

        return await state_storage.get_state_ids_for_group(
            self.state_group, state_filter
        )

    return in_memory_state

def __len__(self) -> int:
return len(self.state)
return len(self.state) if self.state else 0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sanity check: is there ever a situation where self.state is not None but len(self.state) == 0? (I.e. is there a risk of introducing confusion here?) I don't think so, since every state map ought to have m.room.create. Well, I guess there's the state before the create event... but we shouldn't need to refer to that very often.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure I follow: if self.state is falsy but not None, then len(self.state) should be zero anyway?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, actually, it should be 1, as __len__ is used by the cache to guesstimate the size of the entry. Have changed and commented.



class StateHandler:
Expand Down Expand Up @@ -141,7 +166,7 @@ async def get_current_state_ids(
"""
logger.debug("calling resolve_state_groups from get_current_state_ids")
ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
return ret.state
return await ret.get_state(self._state_storage_controller, StateFilter.all())

async def get_current_users_in_room(
self, room_id: str, latest_event_ids: List[str]
Expand All @@ -165,7 +190,8 @@ async def get_current_users_in_room(

logger.debug("calling resolve_state_groups from get_current_users_in_room")
entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
return await self.store.get_joined_users_from_state(room_id, entry)
state = await entry.get_state(self._state_storage_controller, StateFilter.all())
return await self.store.get_joined_users_from_state(room_id, state, entry)

async def get_hosts_in_room_at_events(
self, room_id: str, event_ids: Collection[str]
Expand All @@ -180,7 +206,8 @@ async def get_hosts_in_room_at_events(
The hosts in the room at the given events
"""
entry = await self.resolve_state_groups_for_events(room_id, event_ids)
return await self.store.get_joined_hosts(room_id, entry)
state = await entry.get_state(self._state_storage_controller, StateFilter.all())
return await self.store.get_joined_hosts(room_id, state, entry)

async def compute_event_context(
self,
Expand Down Expand Up @@ -215,10 +242,19 @@ async def compute_event_context(
#
if state_ids_before_event:
# if we're given the state before the event, then we use that
state_group_before_event = None
state_group_before_event_prev_group = None
deltas_to_state_group_before_event = None
entry = None

# .. though we need to get a state group for it.
state_group_before_event = (
await self._state_storage_controller.store_state_group(
event.event_id,
event.room_id,
prev_group=None,
delta_ids=None,
current_state_ids=state_ids_before_event,
)
)

else:
# otherwise, we'll need to resolve the state across the prev_events.
Expand Down Expand Up @@ -252,36 +288,27 @@ async def compute_event_context(
await_full_state=False,
)

state_ids_before_event = entry.state
state_group_before_event = entry.state_group
state_group_before_event_prev_group = entry.prev_group
deltas_to_state_group_before_event = entry.delta_ids

#
# make sure that we have a state group at that point. If it's not a state event,
# that will be the state group for the new event. If it *is* a state event,
# it might get rejected (in which case we'll need to persist it with the
# previous state group)
#

if not state_group_before_event:
state_group_before_event = (
await self._state_storage_controller.store_state_group(
event.event_id,
event.room_id,
prev_group=state_group_before_event_prev_group,
delta_ids=deltas_to_state_group_before_event,
current_state_ids=state_ids_before_event,
# We make sure that we have a state group assigned to the state.
if entry.state_group is None:
state_ids_before_event = await entry.get_state(
self._state_storage_controller, StateFilter.all()
)
state_group_before_event = (
await self._state_storage_controller.store_state_group(
event.event_id,
event.room_id,
prev_group=state_group_before_event_prev_group,
delta_ids=deltas_to_state_group_before_event,
current_state_ids=state_ids_before_event,
)
)
)

# Assign the new state group to the cached state entry.
#
# Note that this can race in that we could generate multiple state
# groups for the same state entry, but that is just inefficient
# rather than dangerous.
if entry and entry.state_group is None:
entry.state_group = state_group_before_event
else:
state_group_before_event = entry.state_group
state_ids_before_event = None

#
# now if it's not a state event, we're done
Expand All @@ -301,6 +328,10 @@ async def compute_event_context(
#
# otherwise, we'll need to create a new state group for after the event
#
if state_ids_before_event is None:
state_ids_before_event = await entry.get_state(
self._state_storage_controller, StateFilter.all()
)

key = (event.type, event.state_key)
if key in state_ids_before_event:
Expand Down
4 changes: 3 additions & 1 deletion synapse/storage/controllers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,6 @@ def __init__(self, hs: "HomeServer", stores: Databases):

self.persistence = None
if stores.persist_events:
self.persistence = EventsPersistenceStorageController(hs, stores)
self.persistence = EventsPersistenceStorageController(
hs, stores, self.state
)
12 changes: 10 additions & 2 deletions synapse/storage/controllers/persist_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,11 @@
from synapse.logging import opentracing
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.controllers.state import StateStorageController
from synapse.storage.databases import Databases
from synapse.storage.databases.main.events import DeltaState
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.state import StateFilter
from synapse.types import (
PersistedEventPosition,
RoomStreamToken,
Expand Down Expand Up @@ -308,7 +310,12 @@ class EventsPersistenceStorageController:
current state and forward extremity changes.
"""

def __init__(self, hs: "HomeServer", stores: Databases):
def __init__(
self,
hs: "HomeServer",
stores: Databases,
state_controller: StateStorageController,
):
# We ultimately want to split out the state store from the main store,
# so we use separate variables here even though they point to the same
# store for now.
Expand All @@ -325,6 +332,7 @@ def __init__(self, hs: "HomeServer", stores: Databases):
self._process_event_persist_queue_task
)
self._state_resolution_handler = hs.get_state_resolution_handler()
self._state_controller = state_controller

async def _process_event_persist_queue_task(
self,
Expand Down Expand Up @@ -504,7 +512,7 @@ async def _calculate_current_state(self, room_id: str) -> StateMap[str]:
state_res_store=StateResolutionStore(self.main_store),
)

return res.state
return await res.get_state(self._state_controller, StateFilter.all())

async def _persist_event_batch(
self, _room_id: str, task: _PersistEventsTask
Expand Down
11 changes: 6 additions & 5 deletions synapse/storage/databases/main/roommember.py
Original file line number Diff line number Diff line change
Expand Up @@ -781,7 +781,7 @@ async def get_mutual_rooms_between_users(
return shared_room_ids or frozenset()

async def get_joined_users_from_state(
self, room_id: str, state_entry: "_StateCacheEntry"
self, room_id: str, state: StateMap[str], state_entry: "_StateCacheEntry"
) -> Dict[str, ProfileInfo]:
state_group: Union[object, int] = state_entry.state_group
if not state_group:
Expand All @@ -794,7 +794,7 @@ async def get_joined_users_from_state(
assert state_group is not None
with Measure(self._clock, "get_joined_users_from_state"):
return await self._get_joined_users_from_context(
room_id, state_group, state_entry.state, context=state_entry
room_id, state_group, state, context=state_entry
)

@cached(num_args=2, iterable=True, max_entries=100000)
Expand Down Expand Up @@ -998,7 +998,7 @@ def get_current_hosts_in_room_txn(txn: LoggingTransaction) -> Set[str]:
)

async def get_joined_hosts(
self, room_id: str, state_entry: "_StateCacheEntry"
self, room_id: str, state: StateMap[str], state_entry: "_StateCacheEntry"
) -> FrozenSet[str]:
state_group: Union[object, int] = state_entry.state_group
if not state_group:
Expand All @@ -1011,14 +1011,15 @@ async def get_joined_hosts(
assert state_group is not None
with Measure(self._clock, "get_joined_hosts"):
return await self._get_joined_hosts(
room_id, state_group, state_entry=state_entry
room_id, state_group, state, state_entry=state_entry
)

@cached(num_args=2, max_entries=10000, iterable=True)
async def _get_joined_hosts(
self,
room_id: str,
state_group: Union[object, int],
state: StateMap[str],
state_entry: "_StateCacheEntry",
) -> FrozenSet[str]:
# We don't use `state_group`, it's there so that we can cache based on
Expand Down Expand Up @@ -1074,7 +1075,7 @@ async def _get_joined_hosts(
# The cache doesn't match the state group or prev state group,
# so we calculate the result from first principles.
joined_users = await self.get_joined_users_from_state(
room_id, state_entry
room_id, state, state_entry
)

cache.hosts_to_joined_users = {}
Expand Down