diff --git a/changelog.d/10645.misc b/changelog.d/10645.misc new file mode 100644 index 000000000000..ac19263cd861 --- /dev/null +++ b/changelog.d/10645.misc @@ -0,0 +1 @@ +Make `backfill` and `get_missing_events` use the same codepath. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 246df43501bc..6fa2fc8f5284 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -65,6 +65,7 @@ from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.events.validator import EventValidator +from synapse.federation.federation_client import InvalidResponseError from synapse.handlers._base import BaseHandler from synapse.http.servlet import assert_params_in_dict from synapse.logging.context import ( @@ -116,10 +117,6 @@ class _NewEventInfo: Attributes: event: the received event - state: the state at that event, according to /state_ids from a remote - homeserver. Only populated for backfilled events which are going to be a - new backwards extremity. - claimed_auth_event_map: a map of (type, state_key) => event for the event's claimed auth_events. @@ -134,7 +131,6 @@ class _NewEventInfo: """ event: EventBase - state: Optional[Sequence[EventBase]] claimed_auth_event_map: StateMap[EventBase] @@ -443,113 +439,7 @@ async def _get_missing_events_for_pdu( return logger.info("Got %d prev_events", len(missing_events)) - await self._process_pulled_events(origin, missing_events) - - async def _get_state_for_room( - self, - destination: str, - room_id: str, - event_id: str, - ) -> List[EventBase]: - """Requests all of the room state at a given event from a remote - homeserver. - - Will also fetch any missing events reported in the `auth_chain_ids` - section of `/state_ids`. - - Args: - destination: The remote homeserver to query for the state. - room_id: The id of the room we're interested in. - event_id: The id of the event we want the state at. - - Returns: - A list of events in the state, not including the event itself. - """ - ( - state_event_ids, - auth_event_ids, - ) = await self.federation_client.get_room_state_ids( - destination, room_id, event_id=event_id - ) - - # Fetch the state events from the DB, and check we have the auth events. - event_map = await self.store.get_events(state_event_ids, allow_rejected=True) - auth_events_in_store = await self.store.have_seen_events( - room_id, auth_event_ids - ) - - # Check for missing events. We handle state and auth event seperately, - # as we want to pull the state from the DB, but we don't for the auth - # events. (Note: we likely won't use the majority of the auth chain, and - # it can be *huge* for large rooms, so it's worth ensuring that we don't - # unnecessarily pull it from the DB). - missing_state_events = set(state_event_ids) - set(event_map) - missing_auth_events = set(auth_event_ids) - set(auth_events_in_store) - if missing_state_events or missing_auth_events: - await self._get_events_and_persist( - destination=destination, - room_id=room_id, - events=missing_state_events | missing_auth_events, - ) - - if missing_state_events: - new_events = await self.store.get_events( - missing_state_events, allow_rejected=True - ) - event_map.update(new_events) - - missing_state_events.difference_update(new_events) - - if missing_state_events: - logger.warning( - "Failed to fetch missing state events for %s %s", - event_id, - missing_state_events, - ) - - if missing_auth_events: - auth_events_in_store = await self.store.have_seen_events( - room_id, missing_auth_events - ) - missing_auth_events.difference_update(auth_events_in_store) - - if missing_auth_events: - logger.warning( - "Failed to fetch missing auth events for %s %s", - event_id, - missing_auth_events, - ) - - remote_state = list(event_map.values()) - - # check for events which were in the wrong room. - # - # this can happen if a remote server claims that the state or - # auth_events at an event in room A are actually events in room B - - bad_events = [ - (event.event_id, event.room_id) - for event in remote_state - if event.room_id != room_id - ] - - for bad_event_id, bad_room_id in bad_events: - # This is a bogus situation, but since we may only discover it a long time - # after it happened, we try our best to carry on, by just omitting the - # bad events from the returned auth/state set. - logger.warning( - "Remote server %s claims event %s in room %s is an auth/state " - "event in room %s", - destination, - bad_event_id, - bad_room_id, - room_id, - ) - - if bad_events: - remote_state = [e for e in remote_state if e.room_id == room_id] - - return remote_state + await self._process_pulled_events(origin, missing_events, backfilled=False) async def _get_state_after_missing_prev_event( self, @@ -567,10 +457,6 @@ async def _get_state_after_missing_prev_event( Returns: A list of events in the state, including the event itself """ - # TODO: This function is basically the same as _get_state_for_room. Can - # we make backfill() use it, rather than having two code paths? I think the - # only difference is that backfill() persists the prev events separately. - ( state_event_ids, auth_event_ids, @@ -681,6 +567,7 @@ async def _process_received_pdu( origin: str, event: EventBase, state: Optional[Iterable[EventBase]], + backfilled: bool = False, ) -> None: """Called when we have a new pdu. We need to do auth checks and put it through the StateHandler. @@ -693,6 +580,9 @@ async def _process_received_pdu( state: Normally None, but if we are handling a gap in the graph (ie, we are missing one or more prev_events), the resolved state at the event + + backfilled: True if this is part of a historical batch of events (inhibits + notification to clients, and validation of device keys.) """ logger.debug("Processing event: %s", event) @@ -700,10 +590,15 @@ async def _process_received_pdu( context = await self.state_handler.compute_event_context( event, old_state=state ) - await self._auth_and_persist_event(origin, event, context, state=state) + await self._auth_and_persist_event( + origin, event, context, state=state, backfilled=backfilled + ) except AuthError as e: raise FederationError("ERROR", e.code, e.msg, affected=event.event_id) + if backfilled: + return + # For encrypted messages we check that we know about the sending device, # if we don't then we mark the device cache for that user as stale. if event.type == EventTypes.Encrypted: @@ -868,7 +763,7 @@ async def _resync_device(self, sender: str) -> None: @log_function async def backfill( self, dest: str, room_id: str, limit: int, extremities: List[str] - ) -> List[EventBase]: + ) -> None: """Trigger a backfill request to `dest` for the given `room_id` This will attempt to get more events from the remote. If the other side @@ -878,6 +773,9 @@ async def backfill( sanity-checking on them. If any of the backfilled events are invalid, this method throws a SynapseError. + We might also raise an InvalidResponseError if the response from the remote + server is just bogus. + TODO: make this more useful to distinguish failures of the remote server from invalid events (there is probably no point in trying to re-fetch invalid events from every other HS in the room.) @@ -890,111 +788,18 @@ async def backfill( ) if not events: - return [] - - # ideally we'd sanity check the events here for excess prev_events etc, - # but it's hard to reject events at this point without completely - # breaking backfill in the same way that it is currently broken by - # events whose signature we cannot verify (#3121). - # - # So for now we accept the events anyway. #3124 tracks this. - # - # for ev in events: - # self._sanity_check_event(ev) - - # Don't bother processing events we already have. - seen_events = await self.store.have_events_in_timeline( - {e.event_id for e in events} - ) - - events = [e for e in events if e.event_id not in seen_events] - - if not events: - return [] - - event_map = {e.event_id: e for e in events} - - event_ids = {e.event_id for e in events} - - # build a list of events whose prev_events weren't in the batch. - # (XXX: this will include events whose prev_events we already have; that doesn't - # sound right?) - edges = [ev.event_id for ev in events if set(ev.prev_event_ids()) - event_ids] - - logger.info("backfill: Got %d events with %d edges", len(events), len(edges)) - - # For each edge get the current state. - - state_events = {} - events_to_state = {} - for e_id in edges: - state = await self._get_state_for_room( - destination=dest, - room_id=room_id, - event_id=e_id, - ) - state_events.update({s.event_id: s for s in state}) - events_to_state[e_id] = state + return - required_auth = { - a_id - for event in events + list(state_events.values()) - for a_id in event.auth_event_ids() - } - auth_events = await self.store.get_events(required_auth, allow_rejected=True) - auth_events.update( - {e_id: event_map[e_id] for e_id in required_auth if e_id in event_map} - ) - - ev_infos = [] - - # Step 1: persist the events in the chunk we fetched state for (i.e. - # the backwards extremities), with custom auth events and state - for e_id in events_to_state: - # For paranoia we ensure that these events are marked as - # non-outliers - ev = event_map[e_id] - assert not ev.internal_metadata.is_outlier() - - ev_infos.append( - _NewEventInfo( - event=ev, - state=events_to_state[e_id], - claimed_auth_event_map={ - ( - auth_events[a_id].type, - auth_events[a_id].state_key, - ): auth_events[a_id] - for a_id in ev.auth_event_ids() - if a_id in auth_events - }, + # if there are any events in the wrong room, the remote server is buggy and + # should not be trusted. + for ev in events: + if ev.room_id != room_id: + raise InvalidResponseError( + f"Remote server {dest} returned event {ev.event_id} which is in " + f"room {ev.room_id}, when we were backfilling in {room_id}" ) - ) - - if ev_infos: - await self._auth_and_persist_events( - dest, room_id, ev_infos, backfilled=True - ) - - # Step 2: Persist the rest of the events in the chunk one by one - events.sort(key=lambda e: e.depth) - - for event in events: - if event in events_to_state: - continue - - # For paranoia we ensure that these events are marked as - # non-outliers - assert not event.internal_metadata.is_outlier() - - context = await self.state_handler.compute_event_context(event) - - # We store these one at a time since each event depends on the - # previous to work out the state. - # TODO: We can probably do something more clever here. - await self._auth_and_persist_event(dest, event, context, backfilled=True) - return events + await self._process_pulled_events(dest, events, backfilled=True) async def maybe_backfill( self, room_id: str, current_depth: int, limit: int @@ -1197,7 +1002,7 @@ async def try_backfill(domains: List[str]) -> bool: # appropriate stuff. # TODO: We can probably do something more intelligent here. return True - except SynapseError as e: + except (SynapseError, InvalidResponseError) as e: logger.info("Failed to backfill from %s because %s", dom, e) continue except HttpResponseException as e: @@ -1351,7 +1156,7 @@ async def get_event(event_id: str): else: logger.info("Missing auth event %s", auth_event_id) - event_infos.append(_NewEventInfo(event, None, auth)) + event_infos.append(_NewEventInfo(event, auth)) if event_infos: await self._auth_and_persist_events( @@ -1361,7 +1166,7 @@ async def get_event(event_id: str): ) async def _process_pulled_events( - self, origin: str, events: Iterable[EventBase] + self, origin: str, events: Iterable[EventBase], backfilled: bool ) -> None: """Process a batch of events we have pulled from a remote server @@ -1373,6 +1178,8 @@ async def _process_pulled_events( Params: origin: The server we received these events from events: The received events. + backfilled: True if this is part of a historical batch of events (inhibits + notification to clients, and validation of device keys.) """ # We want to sort these by depth so we process them and @@ -1381,9 +1188,11 @@ async def _process_pulled_events( for ev in sorted_events: with nested_logging_context(ev.event_id): - await self._process_pulled_event(origin, ev) + await self._process_pulled_event(origin, ev, backfilled=backfilled) - async def _process_pulled_event(self, origin: str, event: EventBase) -> None: + async def _process_pulled_event( + self, origin: str, event: EventBase, backfilled: bool + ) -> None: """Process a single event that we have pulled from a remote server Pulls in any events required to auth the event, persists the received event, @@ -1400,6 +1209,8 @@ async def _process_pulled_event(self, origin: str, event: EventBase) -> None: Params: origin: The server we received this event from events: The received event + backfilled: True if this is part of a historical batch of events (inhibits + notification to clients, and validation of device keys.) """ logger.info("Processing pulled event %s", event) @@ -1428,7 +1239,9 @@ async def _process_pulled_event(self, origin: str, event: EventBase) -> None: try: state = await self._resolve_state_at_missing_prevs(origin, event) - await self._process_received_pdu(origin, event, state=state) + await self._process_received_pdu( + origin, event, state=state, backfilled=backfilled + ) except FederationError as e: if e.code == 403: logger.warning("Pulled event %s failed history check.", event_id) @@ -2451,7 +2264,6 @@ async def _auth_and_persist_events( origin: str, room_id: str, event_infos: Collection[_NewEventInfo], - backfilled: bool = False, ) -> None: """Creates the appropriate contexts and persists events. The events should not depend on one another, e.g. this should be used to persist @@ -2467,16 +2279,12 @@ async def _auth_and_persist_events( async def prep(ev_info: _NewEventInfo): event = ev_info.event with nested_logging_context(suffix=event.event_id): - res = await self.state_handler.compute_event_context( - event, old_state=ev_info.state - ) + res = await self.state_handler.compute_event_context(event) res = await self._check_event_auth( origin, event, res, - state=ev_info.state, claimed_auth_event_map=ev_info.claimed_auth_event_map, - backfilled=backfilled, ) return res @@ -2493,7 +2301,6 @@ async def prep(ev_info: _NewEventInfo): (ev_info.event, context) for ev_info, context in zip(event_infos, contexts) ], - backfilled=backfilled, ) async def _persist_auth_tree( diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py index 664c65dac5a6..bccff5e5b95c 100644 --- a/synapse/storage/databases/main/purge_events.py +++ b/synapse/storage/databases/main/purge_events.py @@ -295,6 +295,7 @@ def _purge_history_txn( self._invalidate_cache_and_stream( txn, self.have_seen_event, (room_id, event_id) ) + self._invalidate_get_event_cache(event_id) logger.info("[purge] done")