@@ -552,17 +552,20 @@ async def _get_state_for_room(
552
552
destination : str ,
553
553
room_id : str ,
554
554
event_id : str ,
555
- ) -> Tuple [List [EventBase ], List [EventBase ]]:
556
- """Requests all of the room state at a given event from a remote homeserver.
555
+ ) -> List [EventBase ]:
556
+ """Requests all of the room state at a given event from a remote
557
+ homeserver.
558
+
559
+ Will also fetch any missing events reported in the `auth_chain_ids`
560
+ section of `/state_ids`.
557
561
558
562
Args:
559
563
destination: The remote homeserver to query for the state.
560
564
room_id: The id of the room we're interested in.
561
565
event_id: The id of the event we want the state at.
562
566
563
567
Returns:
564
- A list of events in the state, not including the event itself, and
565
- a list of events in the auth chain for the given event.
568
+ A list of events in the state, not including the event itself.
566
569
"""
567
570
(
568
571
state_event_ids ,
@@ -571,77 +574,62 @@ async def _get_state_for_room(
571
574
destination , room_id , event_id = event_id
572
575
)
573
576
574
- desired_events = set (state_event_ids + auth_event_ids )
575
-
576
- event_map = await self ._get_events_from_store_or_dest (
577
- destination , room_id , desired_events
578
- )
577
+ # Fetch the state events from the DB, and check we have the auth events.
578
+ event_map = await self .store .get_events (state_event_ids , allow_rejected = True )
579
+ auth_events_in_store = await self .store .have_seen_events (auth_event_ids )
579
580
580
- failed_to_fetch = desired_events - event_map .keys ()
581
- if failed_to_fetch :
582
- logger .warning (
583
- "Failed to fetch missing state/auth events for %s %s" ,
584
- event_id ,
585
- failed_to_fetch ,
581
+ # Check for missing events. We handle state and auth event seperately,
582
+ # as we want to pull the state from the DB, but we don't for the auth
583
+ # events. (Note: we likely won't use the majority of the auth chain, and
584
+ # it can be *huge* for large rooms, so it's worth ensuring that we don't
585
+ # unnecessarily pull it from the DB).
586
+ missing_state_events = set (state_event_ids ) - set (event_map )
587
+ missing_auth_events = set (auth_event_ids ) - set (auth_events_in_store )
588
+ if missing_state_events or missing_auth_events :
589
+ await self ._get_events_and_persist (
590
+ destination = destination ,
591
+ room_id = room_id ,
592
+ events = missing_state_events | missing_auth_events ,
586
593
)
587
594
588
- remote_state = [
589
- event_map [e_id ] for e_id in state_event_ids if e_id in event_map
590
- ]
591
-
592
- auth_chain = [event_map [e_id ] for e_id in auth_event_ids if e_id in event_map ]
593
- auth_chain .sort (key = lambda e : e .depth )
594
-
595
- return remote_state , auth_chain
596
-
597
- async def _get_events_from_store_or_dest (
598
- self , destination : str , room_id : str , event_ids : Iterable [str ]
599
- ) -> Dict [str , EventBase ]:
600
- """Fetch events from a remote destination, checking if we already have them.
601
-
602
- Persists any events we don't already have as outliers.
603
-
604
- If we fail to fetch any of the events, a warning will be logged, and the event
605
- will be omitted from the result. Likewise, any events which turn out not to
606
- be in the given room.
595
+ if missing_state_events :
596
+ new_events = await self .store .get_events (
597
+ missing_state_events , allow_rejected = True
598
+ )
599
+ event_map .update (new_events )
607
600
608
- This function *does not* automatically get missing auth events of the
609
- newly fetched events. Callers must include the full auth chain of
610
- of the missing events in the `event_ids` argument, to ensure that any
611
- missing auth events are correctly fetched.
601
+ missing_state_events .difference_update (new_events )
612
602
613
- Returns :
614
- map from event_id to event
615
- """
616
- fetched_events = await self . store . get_events ( event_ids , allow_rejected = True )
617
-
618
- missing_events = set ( event_ids ) - fetched_events . keys ( )
603
+ if missing_state_events :
604
+ logger . warning (
605
+ "Failed to fetch missing state events for %s %s" ,
606
+ event_id ,
607
+ missing_state_events ,
608
+ )
619
609
620
- if missing_events :
621
- logger .debug (
622
- "Fetching unknown state/auth events %s for room %s" ,
623
- missing_events ,
624
- room_id ,
625
- )
610
+ if missing_auth_events :
611
+ auth_events_in_store = await self .store .have_seen_events (
612
+ missing_auth_events
613
+ )
614
+ missing_auth_events .difference_update (auth_events_in_store )
626
615
627
- await self ._get_events_and_persist (
628
- destination = destination , room_id = room_id , events = missing_events
629
- )
616
+ if missing_auth_events :
617
+ logger .warning (
618
+ "Failed to fetch missing auth events for %s %s" ,
619
+ event_id ,
620
+ missing_auth_events ,
621
+ )
630
622
631
- # we need to make sure we re-load from the database to get the rejected
632
- # state correct.
633
- fetched_events .update (
634
- (await self .store .get_events (missing_events , allow_rejected = True ))
635
- )
623
+ remote_state = list (event_map .values ())
636
624
637
625
# check for events which were in the wrong room.
638
626
#
639
627
# this can happen if a remote server claims that the state or
640
628
# auth_events at an event in room A are actually events in room B
641
629
642
630
bad_events = [
643
- (event_id , event .room_id )
644
- for event_id , event in fetched_events . items ()
631
+ (event . event_id , event .room_id )
632
+ for event in remote_state
645
633
if event .room_id != room_id
646
634
]
647
635
@@ -658,9 +646,10 @@ async def _get_events_from_store_or_dest(
658
646
room_id ,
659
647
)
660
648
661
- del fetched_events [bad_event_id ]
649
+ if bad_events :
650
+ remote_state = [e for e in remote_state if e .room_id == room_id ]
662
651
663
- return fetched_events
652
+ return remote_state
664
653
665
654
async def _get_state_after_missing_prev_event (
666
655
self ,
@@ -963,27 +952,23 @@ async def backfill(
963
952
964
953
# For each edge get the current state.
965
954
966
- auth_events = {}
967
955
state_events = {}
968
956
events_to_state = {}
969
957
for e_id in edges :
970
- state , auth = await self ._get_state_for_room (
958
+ state = await self ._get_state_for_room (
971
959
destination = dest ,
972
960
room_id = room_id ,
973
961
event_id = e_id ,
974
962
)
975
- auth_events .update ({a .event_id : a for a in auth })
976
- auth_events .update ({s .event_id : s for s in state })
977
963
state_events .update ({s .event_id : s for s in state })
978
964
events_to_state [e_id ] = state
979
965
980
966
required_auth = {
981
967
a_id
982
- for event in events
983
- + list (state_events .values ())
984
- + list (auth_events .values ())
968
+ for event in events + list (state_events .values ())
985
969
for a_id in event .auth_event_ids ()
986
970
}
971
+ auth_events = await self .store .get_events (required_auth , allow_rejected = True )
987
972
auth_events .update (
988
973
{e_id : event_map [e_id ] for e_id in required_auth if e_id in event_map }
989
974
)
0 commit comments