1
1
# Copyright 2014-2016 OpenMarket Ltd
2
+ # Copyright 2022 The Matrix.org Foundation C.I.C.
2
3
#
3
4
# Licensed under the Apache License, Version 2.0 (the "License");
4
5
# you may not use this file except in compliance with the License.
15
16
from typing import (
16
17
TYPE_CHECKING ,
17
18
Awaitable ,
19
+ Callable ,
18
20
Collection ,
19
21
Dict ,
20
22
Iterable ,
@@ -532,6 +534,44 @@ def approx_difference(self, other: "StateFilter") -> "StateFilter":
532
534
new_all , new_excludes , new_wildcards , new_concrete_keys
533
535
)
534
536
537
+ def must_await_full_state (self , is_mine_id : Callable [[str ], bool ]) -> bool :
538
+ """Check if we need to wait for full state to complete to calculate this state
539
+
540
+ If we have a state filter which is completely satisfied even with partial
541
+ state, then we don't need to await_full_state before we can return it.
542
+
543
+ Args:
544
+ is_mine_id: a callable which confirms if a given state_key matches a mxid
545
+ of a local user
546
+ """
547
+
548
+ # TODO(faster_joins): it's not entirely clear that this is safe. In particular,
549
+ # there may be circumstances in which we return a piece of state that, once we
550
+ # resync the state, we discover is invalid. For example: if it turns out that
551
+ # the sender of a piece of state wasn't actually in the room, then clearly that
552
+ # state shouldn't have been returned.
553
+ # We should at least add some tests around this to see what happens.
554
+
555
+ # if we haven't requested membership events, then it depends on the value of
556
+ # 'include_others'
557
+ if EventTypes .Member not in self .types :
558
+ return self .include_others
559
+
560
+ # if we're looking for *all* membership events, then we have to wait
561
+ member_state_keys = self .types [EventTypes .Member ]
562
+ if member_state_keys is None :
563
+ return True
564
+
565
+ # otherwise, consider whose membership we are looking for. If it's entirely
566
+ # local users, then we don't need to wait.
567
+ for state_key in member_state_keys :
568
+ if not is_mine_id (state_key ):
569
+ # remote user
570
+ return True
571
+
572
+ # local users only
573
+ return False
574
+
535
575
536
576
_ALL_STATE_FILTER = StateFilter (types = frozendict (), include_others = True )
537
577
_ALL_NON_MEMBER_STATE_FILTER = StateFilter (
@@ -544,6 +584,7 @@ class StateGroupStorage:
544
584
"""High level interface to fetching state for event."""
545
585
546
586
def __init__ (self , hs : "HomeServer" , stores : "Databases" ):
587
+ self ._is_mine_id = hs .is_mine_id
547
588
self .stores = stores
548
589
self ._partial_state_events_tracker = PartialStateEventsTracker (stores .main )
549
590
@@ -675,7 +716,13 @@ async def get_state_for_events(
675
716
RuntimeError if we don't have a state group for one or more of the events
676
717
(ie they are outliers or unknown)
677
718
"""
678
- event_to_groups = await self .get_state_group_for_events (event_ids )
719
+ await_full_state = True
720
+ if state_filter and not state_filter .must_await_full_state (self ._is_mine_id ):
721
+ await_full_state = False
722
+
723
+ event_to_groups = await self .get_state_group_for_events (
724
+ event_ids , await_full_state = await_full_state
725
+ )
679
726
680
727
groups = set (event_to_groups .values ())
681
728
group_to_state = await self .stores .state ._get_state_for_groups (
@@ -699,7 +746,9 @@ async def get_state_for_events(
699
746
return {event : event_to_state [event ] for event in event_ids }
700
747
701
748
async def get_state_ids_for_events (
702
- self , event_ids : Collection [str ], state_filter : Optional [StateFilter ] = None
749
+ self ,
750
+ event_ids : Collection [str ],
751
+ state_filter : Optional [StateFilter ] = None ,
703
752
) -> Dict [str , StateMap [str ]]:
704
753
"""
705
754
Get the state dicts corresponding to a list of events, containing the event_ids
@@ -716,7 +765,13 @@ async def get_state_ids_for_events(
716
765
RuntimeError if we don't have a state group for one or more of the events
717
766
(ie they are outliers or unknown)
718
767
"""
719
- event_to_groups = await self .get_state_group_for_events (event_ids )
768
+ await_full_state = True
769
+ if state_filter and not state_filter .must_await_full_state (self ._is_mine_id ):
770
+ await_full_state = False
771
+
772
+ event_to_groups = await self .get_state_group_for_events (
773
+ event_ids , await_full_state = await_full_state
774
+ )
720
775
721
776
groups = set (event_to_groups .values ())
722
777
group_to_state = await self .stores .state ._get_state_for_groups (
@@ -802,7 +857,7 @@ async def get_state_group_for_events(
802
857
Args:
803
858
event_ids: events to get state groups for
804
859
await_full_state: if true, will block if we do not yet have complete
805
- state at this event .
860
+ state at these events .
806
861
"""
807
862
if await_full_state :
808
863
await self ._partial_state_events_tracker .await_full_state (event_ids )
0 commit comments