Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Persist auth/state events at backwards extremities when we fetch them #6526

Merged
merged 7 commits into from
Dec 16, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/6526.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix a bug which could cause the federation server to incorrectly return errors when handling certain obscure event graphs.
247 changes: 80 additions & 167 deletions synapse/handlers/federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,7 @@
from synapse.state import StateResolutionStore, resolve_events_with_store
from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
from synapse.types import UserID, get_domain_from_id
from synapse.util import batch_iter, unwrapFirstError
from synapse.util.async_helpers import Linearizer
from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.distributor import user_joined_room
from synapse.util.retryutils import NotRetryingDestination
from synapse.visibility import filter_events_for_server
Expand Down Expand Up @@ -238,7 +237,6 @@ async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None:
return None

state = None
auth_chain = []

# Get missing pdus if necessary.
if not pdu.internal_metadata.is_outlier():
Expand Down Expand Up @@ -348,7 +346,6 @@ async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None:

# Calculate the state after each of the previous events, and
# resolve them to find the correct state at the current event.
auth_chains = set()
event_map = {event_id: pdu}
try:
# Get the state of the events we know about
Expand All @@ -369,24 +366,14 @@ async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None:
"Requesting state at missing prev_event %s", event_id,
)

room_version = await self.store.get_room_version(room_id)

with nested_logging_context(p):
# note that if any of the missing prevs share missing state or
# auth events, the requests to fetch those events are deduped
# by the get_pdu_cache in federation_client.
(
remote_state,
got_auth_chain,
) = await self._get_state_for_room(
(remote_state, _,) = await self._get_state_for_room(
origin, room_id, p, include_event_in_state=True
)

# XXX hrm I'm not convinced that duplicate events will compare
# for equality, so I'm not sure this does what the author
# hoped.
auth_chains.update(got_auth_chain)

remote_state_map = {
(x.type, x.state_key): x.event_id for x in remote_state
}
Expand All @@ -395,6 +382,7 @@ async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None:
for x in remote_state:
event_map[x.event_id] = x

room_version = await self.store.get_room_version(room_id)
state_map = await resolve_events_with_store(
room_version,
state_maps,
Expand All @@ -415,7 +403,6 @@ async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None:
event_map.update(evs)

state = [event_map[e] for e in six.itervalues(state_map)]
auth_chain = list(auth_chains)
except Exception:
logger.warning(
"[%s %s] Error attempting to resolve state at missing "
Expand All @@ -431,9 +418,7 @@ async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None:
affected=event_id,
)

await self._process_received_pdu(
origin, pdu, state=state, auth_chain=auth_chain
)
await self._process_received_pdu(origin, pdu, state=state)

async def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
"""
Expand Down Expand Up @@ -632,10 +617,7 @@ async def _get_events_from_store_or_dest(
) -> Dict[str, EventBase]:
"""Fetch events from a remote destination, checking if we already have them.

Args:
destination
room_id
event_ids
Persists any events we don't already have as outliers.

If we fail to fetch any of the events, a warning will be logged, and the event
will be omitted from the result. Likewise, any events which turn out not to
Expand All @@ -655,27 +637,15 @@ async def _get_events_from_store_or_dest(
room_id,
)

room_version = await self.store.get_room_version(room_id)

# XXX 20 requests at once? really?
for batch in batch_iter(missing_events, 20):
deferreds = [
run_in_background(
self.federation_client.get_pdu,
destinations=[destination],
event_id=e_id,
room_version=room_version,
)
for e_id in batch
]

res = await make_deferred_yieldable(
defer.DeferredList(deferreds, consumeErrors=True)
)
await self._get_events_and_persist(
destination=destination, room_id=room_id, events=missing_events
)

for success, result in res:
if success and result:
fetched_events[result.event_id] = result
# we need to make sure we re-load from the database to get the rejected
# state correct.
fetched_events.update(
(await self.store.get_events(missing_events, allow_rejected=True))
)

# check for events which were in the wrong room.
#
Expand Down Expand Up @@ -704,50 +674,26 @@ async def _get_events_from_store_or_dest(

return fetched_events

async def _process_received_pdu(self, origin, event, state, auth_chain):
async def _process_received_pdu(
self, origin: str, event: EventBase, state: Optional[Iterable[EventBase]],
):
""" Called when we have a new pdu. We need to do auth checks and put it
through the StateHandler.

Args:
origin: server sending the event

event: event to be persisted

state: Normally None, but if we are handling a gap in the graph
(ie, we are missing one or more prev_events), the resolved state at the
event
"""
room_id = event.room_id
event_id = event.event_id

logger.debug("[%s %s] Processing event: %s", room_id, event_id, event)

event_ids = set()
if state:
event_ids |= {e.event_id for e in state}
if auth_chain:
event_ids |= {e.event_id for e in auth_chain}

seen_ids = await self.store.have_seen_events(event_ids)

if state and auth_chain is not None:
# If we have any state or auth_chain given to us by the replication
# layer, then we should handle them (if we haven't before.)

event_infos = []

for e in itertools.chain(auth_chain, state):
if e.event_id in seen_ids:
continue
e.internal_metadata.outlier = True
auth_ids = e.auth_event_ids()
auth = {
(e.type, e.state_key): e
for e in auth_chain
if e.event_id in auth_ids or e.type == EventTypes.Create
}
event_infos.append(_NewEventInfo(event=e, auth_events=auth))
seen_ids.add(e.event_id)

logger.info(
"[%s %s] persisting newly-received auth/state events %s",
room_id,
event_id,
[e.event.event_id for e in event_infos],
)
await self._handle_new_events(origin, event_infos)

try:
context = await self._handle_new_event(origin, event, state=state)
except AuthError as e:
Expand Down Expand Up @@ -802,8 +748,6 @@ async def backfill(self, dest, room_id, limit, extremities):
if dest == self.server_name:
raise SynapseError(400, "Can't backfill from self.")

room_version = await self.store.get_room_version(room_id)

events = await self.federation_client.backfill(
dest, room_id, limit=limit, extremities=extremities
)
Expand Down Expand Up @@ -832,6 +776,9 @@ async def backfill(self, dest, room_id, limit, extremities):

event_ids = set(e.event_id for e in events)

# build a list of events whose prev_events weren't in the batch.
# (XXX: this will include events whose prev_events we already have; that doesn't
# sound right?)
edges = [ev.event_id for ev in events if set(ev.prev_event_ids()) - event_ids]

logger.info("backfill: Got %d events with %d edges", len(events), len(edges))
Expand Down Expand Up @@ -860,95 +807,11 @@ async def backfill(self, dest, room_id, limit, extremities):
auth_events.update(
{e_id: event_map[e_id] for e_id in required_auth if e_id in event_map}
)
missing_auth = required_auth - set(auth_events)
failed_to_fetch = set()

# Try and fetch any missing auth events from both DB and remote servers.
# We repeatedly do this until we stop finding new auth events.
while missing_auth - failed_to_fetch:
logger.info("Missing auth for backfill: %r", missing_auth)
ret_events = await self.store.get_events(missing_auth - failed_to_fetch)
auth_events.update(ret_events)

required_auth.update(
a_id for event in ret_events.values() for a_id in event.auth_event_ids()
)
missing_auth = required_auth - set(auth_events)

if missing_auth - failed_to_fetch:
logger.info(
"Fetching missing auth for backfill: %r",
missing_auth - failed_to_fetch,
)

results = await make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(
self.federation_client.get_pdu,
[dest],
event_id,
room_version=room_version,
outlier=True,
timeout=10000,
)
for event_id in missing_auth - failed_to_fetch
],
consumeErrors=True,
)
).addErrback(unwrapFirstError)
auth_events.update({a.event_id: a for a in results if a})
required_auth.update(
a_id
for event in results
if event
for a_id in event.auth_event_ids()
)
missing_auth = required_auth - set(auth_events)

failed_to_fetch = missing_auth - set(auth_events)

seen_events = await self.store.have_seen_events(
set(auth_events.keys()) | set(state_events.keys())
)

# We now have a chunk of events plus associated state and auth chain to
# persist. We do the persistence in two steps:
# 1. Auth events and state get persisted as outliers, plus the
# backward extremities get persisted (as non-outliers).
# 2. The rest of the events in the chunk get persisted one by one, as
# each one depends on the previous event for its state.
#
# The important thing is that events in the chunk get persisted as
# non-outliers, including when those events are also in the state or
# auth chain. Caution must therefore be taken to ensure that they are
# not accidentally marked as outliers.

# Step 1a: persist auth events that *don't* appear in the chunk
ev_infos = []
for a in auth_events.values():
# We only want to persist auth events as outliers that we haven't
# seen and aren't about to persist as part of the backfilled chunk.
if a.event_id in seen_events or a.event_id in event_map:
continue

a.internal_metadata.outlier = True
ev_infos.append(
_NewEventInfo(
event=a,
auth_events={
(
auth_events[a_id].type,
auth_events[a_id].state_key,
): auth_events[a_id]
for a_id in a.auth_event_ids()
if a_id in auth_events
},
)
)

# Step 1b: persist the events in the chunk we fetched state for (i.e.
# the backwards extremities) as non-outliers.
# Step 1: persist the events in the chunk we fetched state for (i.e.
# the backwards extremities), with custom auth events and state
for e_id in events_to_state:
# For paranoia we ensure that these events are marked as
# non-outliers
Expand Down Expand Up @@ -1190,6 +1053,56 @@ async def try_backfill(domains):

return False

async def _get_events_and_persist(
    self, destination: str, room_id: str, events: Iterable[str]
):
    """Fetch the given events from a server, and persist them as outliers.

    Each event is fetched individually over federation, its auth events are
    resolved (from the local store or the same destination), and the whole
    batch is then persisted via ``_handle_new_events``.

    Logs a warning if we can't find a given event; failures are best-effort
    and do not abort the rest of the batch.

    Args:
        destination: the remote server to fetch the events from
        room_id: the room the events belong to (used to look up the room
            version for the fetch)
        events: the IDs of the events to fetch
    """

    room_version = await self.store.get_room_version(room_id)

    event_infos = []

    async def get_event(event_id: str):
        with nested_logging_context(event_id):
            try:
                event = await self.federation_client.get_pdu(
                    [destination], event_id, room_version, outlier=True,
                )
                if event is None:
                    logger.warning(
                        "Server %s didn't return event %s", destination, event_id,
                    )
                    return

                # recursively fetch the auth events for this event
                # (_get_events_from_store_or_dest calls back into this
                # method for anything it doesn't already have)
                auth_events = await self._get_events_from_store_or_dest(
                    destination, room_id, event.auth_event_ids()
                )
                # map (type, state_key) -> event for whichever auth events
                # we managed to obtain; missing ones are simply omitted
                auth = {}
                for auth_event_id in event.auth_event_ids():
                    ae = auth_events.get(auth_event_id)
                    if ae:
                        auth[(ae.type, ae.state_key)] = ae

                # use keyword arguments for consistency with the other
                # _NewEventInfo construction sites in this file (and to
                # avoid an opaque positional None)
                event_infos.append(
                    _NewEventInfo(event=event, state=None, auth_events=auth)
                )

            except Exception as e:
                # deliberately broad: a failure to fetch one event should
                # not abort the rest of the batch
                logger.warning(
                    "Error fetching missing state/auth event %s: %s %s",
                    event_id,
                    type(e),
                    e,
                )

    # fetch up to five events concurrently; requests for shared auth events
    # are deduped further down the stack
    await concurrently_execute(get_event, events, 5)

    await self._handle_new_events(
        destination, event_infos,
    )

def _sanity_check_event(self, ev):
"""
Do some early sanity checks of a received event
Expand Down
4 changes: 2 additions & 2 deletions synapse/util/async_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,8 @@ def concurrently_execute(func, args, limit):

Args:
func (func): Function to execute, should return a deferred or coroutine.
args (list): List of arguments to pass to func, each invocation of func
gets a signle argument.
args (Iterable): List of arguments to pass to func, each invocation of func
gets a single argument.
limit (int): Maximum number of concurrent executions.

Returns:
Expand Down