Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Refactor _resolve_state_at_missing_prevs to return an EventContext #13404

122 changes: 42 additions & 80 deletions synapse/handlers/federation_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,19 +278,15 @@ async def on_receive_pdu(self, origin: str, pdu: EventBase) -> None:
)

try:
await self._process_received_pdu(
origin, pdu, state_ids=None, partial_state=None
)
await self._process_received_pdu(origin, pdu, context=None)
except PartialStateConflictError:
# The room was un-partial stated while we were processing the PDU.
# Try once more, with full state this time.
logger.info(
"Room %s was un-partial stated while processing the PDU, trying again.",
room_id,
)
await self._process_received_pdu(
origin, pdu, state_ids=None, partial_state=None
)
await self._process_received_pdu(origin, pdu, context=None)

async def on_send_membership_event(
self, origin: str, event: EventBase
Expand Down Expand Up @@ -320,6 +316,7 @@ async def on_send_membership_event(
The event and context of the event after inserting it into the room graph.

Raises:
RuntimeError if any prev_events are missing
SynapseError if the event is not accepted into the room
PartialStateConflictError if the room was un-partial stated in between
computing the state at the event and persisting it. The caller should
Expand Down Expand Up @@ -380,7 +377,7 @@ async def on_send_membership_event(
# need to.
await self._event_creation_handler.cache_joined_hosts_for_event(event, context)

await self._check_for_soft_fail(event, None, origin=origin)
await self._check_for_soft_fail(event, context=context, origin=origin)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have all the prev events here, otherwise we would have raised a RuntimeError when computing the event context. Thus the soft fail check will behave the same as before.

await self._run_push_actions_and_persist_event(event, context)
return event, context

Expand Down Expand Up @@ -538,36 +535,10 @@ async def update_state_for_partial_state_event(
#
# This is the same operation as we do when we receive a regular event
# over federation.
state_ids, partial_state = await self._resolve_state_at_missing_prevs(
context = await self._compute_event_context_with_missing_prevs(
destination, event
)

# There are three possible cases for (state_ids, partial_state):
# * `state_ids` and `partial_state` are both `None` if we had all the
# prev_events. The prev_events may or may not have partial state and
# we won't know until we compute the event context.
# * `state_ids` is not `None` and `partial_state` is `False` if we were
# missing some prev_events (but we have full state for any we did
# have). We calculated the full state after the prev_events.
# * `state_ids` is not `None` and `partial_state` is `True` if we were
# missing some, but not all, prev_events. At least one of the
# prev_events we did have had partial state, so we calculated a partial
# state after the prev_events.

context = None
if state_ids is not None and partial_state:
# the state after the prev events is still partial. We can't de-partial
# state the event, so don't bother building the event context.
pass
else:
# build a new state group for it if need be
context = await self._state_handler.compute_event_context(
event,
state_ids_before_event=state_ids,
partial_state=partial_state,
)

if context is None or context.partial_state:
Comment on lines -544 to -570
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We lose the optimization here in exchange for simpler code.

if context.partial_state:
# this can happen if some or all of the event's prev_events still have
# partial state. We were careful to only pick events from the db without
# partial-state prev events, so that implies that a prev event has
Expand Down Expand Up @@ -833,26 +804,25 @@ async def _process_pulled_event(

try:
try:
state_ids, partial_state = await self._resolve_state_at_missing_prevs(
context = await self._compute_event_context_with_missing_prevs(
origin, event
)
await self._process_received_pdu(
origin,
event,
state_ids=state_ids,
partial_state=partial_state,
context=context,
backfilled=backfilled,
)
except PartialStateConflictError:
# The room was un-partial stated while we were processing the event.
# Try once more, with full state this time.
state_ids, partial_state = await self._resolve_state_at_missing_prevs(
context = await self._compute_event_context_with_missing_prevs(
origin, event
)

# We ought to have full state now, barring some unlikely race where we left and
# rejoned the room in the background.
if state_ids is not None and partial_state:
if context.partial_state:
raise AssertionError(
f"Event {event.event_id} still has a partial resolved state "
f"after room {event.room_id} was un-partial stated"
Expand All @@ -861,8 +831,7 @@ async def _process_pulled_event(
await self._process_received_pdu(
origin,
event,
state_ids=state_ids,
partial_state=partial_state,
context=context,
backfilled=backfilled,
)
except FederationError as e:
Expand All @@ -871,15 +840,18 @@ async def _process_pulled_event(
else:
raise

async def _resolve_state_at_missing_prevs(
async def _compute_event_context_with_missing_prevs(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

given we don't know if we have missing prevs, this is a bit of a confusing name (notwithstanding the old name being just as confusing).

Unfortunately I can't think of a better name than _compute_event_context_with_maybe_missing_prevs

self, dest: str, event: EventBase
) -> Tuple[Optional[StateMap[str]], Optional[bool]]:
"""Calculate the state at an event with missing prev_events.
) -> EventContext:
"""Build an EventContext structure for a non-outlier event whose prev_events may
be missing.

This is used when we have pulled a batch of events from a remote server, and
still don't have all the prev_events.
This is used when we have pulled a batch of events from a remote server, and may
not have all the prev_events.

If we already have all the prev_events for `event`, this method does nothing.
To build an EventContext, we need to calculate the state at the event. If we
already have all the prev_events for `event`, we can simply use the state at the
prev_events to calculate the state at `event`.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

s/state at/state before/


Otherwise, the missing prevs become new backwards extremities, and we fall back
to asking the remote server for the state after each missing `prev_event`,
Expand All @@ -900,10 +872,7 @@ async def _resolve_state_at_missing_prevs(
event: an event to check for missing prevs.

Returns:
if we already had all the prev events, `None, None`. Otherwise, returns a
tuple containing:
* the event ids of the state at `event`.
* a boolean indicating whether the state may be partial.
The event context.

Raises:
FederationError if we fail to get the state from the remote server after any
Expand All @@ -917,7 +886,7 @@ async def _resolve_state_at_missing_prevs(
missing_prevs = prevs - seen

if not missing_prevs:
return None, None
return await self._state_handler.compute_event_context(event)

logger.info(
"Event %s is missing prev_events %s: calculating state for a "
Expand Down Expand Up @@ -983,7 +952,9 @@ async def _resolve_state_at_missing_prevs(
"We can't get valid state history.",
affected=event_id,
)
return state_map, partial_state
return await self._state_handler.compute_event_context(
event, state_ids_before_event=state_map, partial_state=partial_state
)

async def _get_state_ids_after_missing_prev_event(
self,
Expand Down Expand Up @@ -1152,8 +1123,7 @@ async def _process_received_pdu(
self,
origin: str,
event: EventBase,
state_ids: Optional[StateMap[str]],
partial_state: Optional[bool],
context: Optional[EventContext],
backfilled: bool = False,
) -> None:
"""Called when we have a new non-outlier event.
Expand All @@ -1175,32 +1145,22 @@ async def _process_received_pdu(

event: event to be persisted

state_ids: Normally None, but if we are handling a gap in the graph
(ie, we are missing one or more prev_events), the resolved state at the
event

partial_state:
`True` if `state_ids` is partial and omits non-critical membership
events.
`False` if `state_ids` is the full state.
`None` if `state_ids` is not provided. In this case, the flag will be
calculated based on `event`'s prev events.
context: The `EventContext` to persist the event with.

backfilled: True if this is part of a historical batch of events (inhibits
notification to clients, and validation of device keys.)

PartialStateConflictError: if the room was un-partial stated in between
computing the state at the event and persisting it. The caller should retry
exactly once in this case.
exactly once in this case. If a `context` was provided, it should be
recomputed.
"""
logger.debug("Processing event: %s", event)
assert not event.internal_metadata.outlier

context = await self._state_handler.compute_event_context(
event,
state_ids_before_event=state_ids,
partial_state=partial_state,
)
if context is None:
context = await self._state_handler.compute_event_context(event)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like it would be cleaner to make context mandatory and add a couple of calls to compute_event_context to on_receive_pdu.


try:
await self._check_event_auth(origin, event, context)
except AuthError as e:
Expand All @@ -1212,7 +1172,7 @@ async def _process_received_pdu(
# For new (non-backfilled and non-outlier) events we check if the event
# passes auth based on the current state. If it doesn't then we
# "soft-fail" the event.
await self._check_for_soft_fail(event, state_ids, origin=origin)
await self._check_for_soft_fail(event, context=context, origin=origin)

await self._run_push_actions_and_persist_event(event, context, backfilled)

Expand Down Expand Up @@ -1773,7 +1733,7 @@ async def _maybe_kick_guest_users(self, event: EventBase) -> None:
async def _check_for_soft_fail(
self,
event: EventBase,
state_ids: Optional[StateMap[str]],
context: EventContext,
origin: str,
) -> None:
Comment on lines 1730 to 1735
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It only makes sense to call _check_for_soft_fail prior to persisting an event and EventContexts are intended to hold information relevant to persisting an event. So we aren't losing much by accepting an EventContext instead of a StateMap.

"""Checks if we should soft fail the event; if so, marks the event as
Expand All @@ -1784,7 +1744,7 @@ async def _check_for_soft_fail(

Args:
event
state_ids: The state at the event if we don't have all the event's prev events
context: The `EventContext` which we are about to persist the event with.
origin: The host the event originates from.
"""
if await self._store.is_partial_state_room(event.room_id):
Expand All @@ -1810,11 +1770,12 @@ async def _check_for_soft_fail(
auth_types = auth_types_for_event(room_version_obj, event)

# Calculate the "current state".
if state_ids is not None:
# If we're explicitly given the state then we won't have all the
# prev events, and so we have a gap in the graph. In this case
# we want to be a little careful as we might have been down for
# a while and have an incorrect view of the current state,
seen_event_ids = await self._store.have_events_in_timeline(prev_event_ids)
has_missing_prevs = bool(prev_event_ids - seen_event_ids)
if has_missing_prevs:
# We don't have all the prev events, and have a gap in the graph.
squahtx marked this conversation as resolved.
Show resolved Hide resolved
# In this case we want to be a little careful as we might have been
# down for a while and have an incorrect view of the current state,
# however we still want to do checks as gaps are easy to
# maliciously manufacture.
#
Expand All @@ -1827,6 +1788,7 @@ async def _check_for_soft_fail(
event.room_id, extrem_ids
)
state_sets: List[StateMap[str]] = list(state_sets_d.values())
state_ids = await context.get_prev_state_ids()
state_sets.append(state_ids)
current_state_ids = (
await self._state_resolution_handler.resolve_events_with_store(
Expand Down
8 changes: 8 additions & 0 deletions synapse/state/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,10 @@ async def compute_event_context(
flag will be calculated based on `event`'s prev events.
Returns:
The event context.

Raises:
RuntimeError if `state_ids_before_event` is not provided and one or more
prev events are missing or outliers.
"""

assert not event.internal_metadata.is_outlier()
Expand Down Expand Up @@ -432,6 +436,10 @@ async def resolve_state_groups_for_events(

Returns:
The resolved state

Raises:
RuntimeError if we don't have a state group for one or more of the events
(ie. they are outliers or unknown)
"""
logger.debug("resolve_state_groups event_ids %s", event_ids)

Expand Down
4 changes: 4 additions & 0 deletions synapse/storage/controllers/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,10 @@ async def get_state_group_for_events(
event_ids: events to get state groups for
await_full_state: if true, will block if we do not yet have complete
state at these events.

Raises:
RuntimeError if we don't have a state group for one or more of the events
(ie. they are outliers or unknown)
"""
if await_full_state:
await self._partial_state_events_tracker.await_full_state(event_ids)
Expand Down
15 changes: 11 additions & 4 deletions tests/handlers/test_federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,16 +280,23 @@ def test_backfill_with_many_backward_extremities(self) -> None:

# we poke this directly into _process_received_pdu, to avoid the
# federation handler wanting to backfill the fake event.
self.get_success(
federation_event_handler._process_received_pdu(
self.OTHER_SERVER_NAME,
state_handler = self.hs.get_state_handler()
context = self.get_success(
state_handler.compute_event_context(
event,
state_ids={
state_ids_before_event={
(e.type, e.state_key): e.event_id for e in current_state
},
partial_state=False,
)
)
self.get_success(
federation_event_handler._process_received_pdu(
self.OTHER_SERVER_NAME,
event,
context=context,
)
)

# we should now have 8 backwards extremities.
backwards_extremities = self.get_success(
Expand Down