From f03b3056292871250c1a57e060daecdf6a9311c4 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Wed, 7 Dec 2022 18:35:51 +0000 Subject: [PATCH 01/13] Declare new config --- .../configuration/config_documentation.md | 57 +++++++++++++------ 1 file changed, 39 insertions(+), 18 deletions(-) diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md index dc5e5ac59795..4d32902fea21 100644 --- a/docs/usage/configuration/config_documentation.md +++ b/docs/usage/configuration/config_documentation.md @@ -2501,32 +2501,53 @@ Config settings related to the client/server API --- ### `room_prejoin_state` -Controls for the state that is shared with users who receive an invite -to a room. By default, the following state event types are shared with users who -receive invites to the room: -- m.room.join_rules -- m.room.canonical_alias -- m.room.avatar -- m.room.encryption -- m.room.name -- m.room.create -- m.room.topic +This setting controls the state that is shared with users upon receiving an +invite to a room, or in reply to a knock on a room. By default, the following +state events are shared with users: + +- `m.room.join_rules` +- `m.room.canonical_alias` +- `m.room.avatar` +- `m.room.encryption` +- `m.room.name` +- `m.room.create` +- `m.room.topic` To change the default behavior, use the following sub-options: -* `disable_default_event_types`: set to true to disable the above defaults. If this - is enabled, only the event types listed in `additional_event_types` are shared. - Defaults to false. -* `additional_event_types`: Additional state event types to share with users when they are invited - to a room. By default, this list is empty (so only the default event types are shared). +* `disable_default_event_types`: boolean. Set to `true` to disable the above + defaults. If this is enabled, only the event types listed in + `additional_event_types` are shared. Defaults to `false`. 
+* `additional_event_types`: A list of additional state events to include in the + events to be shared. By default, this list is empty (so only the default event + types are shared). + + Each entry in this list should be either a single string or a list of two + strings. + * A standalone string `t` represents all events with type `t` (i.e. + with no restrictions on state keys). + * A pair of strings `[t, s]` represents a single event with type `t` and + state key `s`. The same type can appear in two entries with different state + keys: in this situation, both state keys are included in prejoin state. Example configuration: ```yaml room_prejoin_state: - disable_default_event_types: true + disable_default_event_types: false additional_event_types: - - org.example.custom.event.type - - m.room.join_rules + # Share all events of type `org.example.custom.event.typeA` + - org.example.custom.event.typeA + # Share only events of type `org.example.custom.event.typeB` whose + # state_key is "foo" + - ["org.example.custom.event.typeB", "foo"] + # Share only events of type `org.example.custom.event.typeC` whose + # state_key is "bar" or "baz" + - ["org.example.custom.event.typeC", "bar"] + - ["org.example.custom.event.typeC", "baz"] ``` + +*Changed in Synapse 1.74:* admins can filter the events in prejoin state based +on their state key. 
+ --- ### `track_puppeted_user_ips` From c801ce78cda15670d67bb6423ecd28e2653c9680 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Wed, 7 Dec 2022 19:51:52 +0000 Subject: [PATCH 02/13] Parse new config --- mypy.ini | 3 + synapse/config/_util.py | 3 + synapse/config/api.py | 104 +++++++++++++++++++++------ tests/config/test_api.py | 152 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 241 insertions(+), 21 deletions(-) create mode 100644 tests/config/test_api.py diff --git a/mypy.ini b/mypy.ini index c3fbd1a95559..6f4e6f4e268a 100644 --- a/mypy.ini +++ b/mypy.ini @@ -88,6 +88,9 @@ disallow_untyped_defs = False [mypy-tests.*] disallow_untyped_defs = False +[mypy-tests.config.test_api] +disallow_untyped_defs = True + [mypy-tests.handlers.test_user_directory] disallow_untyped_defs = True diff --git a/synapse/config/_util.py b/synapse/config/_util.py index 3edb4b71068f..d3a4b484abb8 100644 --- a/synapse/config/_util.py +++ b/synapse/config/_util.py @@ -33,6 +33,9 @@ def validate_config( config: the configuration value to be validated config_path: the path within the config file. This will be used as a basis for the error message. + + Raises: + ConfigError, if validation fails. """ try: jsonschema.validate(config, json_schema) diff --git a/synapse/config/api.py b/synapse/config/api.py index e46728e73f0a..a29decd6007e 100644 --- a/synapse/config/api.py +++ b/synapse/config/api.py @@ -13,7 +13,9 @@ # limitations under the License. import logging -from typing import Any, Iterable +from typing import Any, Container, Dict, Iterable, Mapping, Optional, Set, Tuple, Type + +import attr from synapse.api.constants import EventTypes from synapse.config._base import Config, ConfigError @@ -23,19 +25,63 @@ logger = logging.getLogger(__name__) +@attr.s(auto_attribs=True) +class StateKeyFilter(Container[str]): + """A simpler version of StateFilter which ignores event types. 
+ + Represents an optional constraint that state_keys must belong to a given set of + strings called `options`. An empty set of `options` means that there are no + restrictions. + """ + + options: Set[str] + + @classmethod + def any(cls: Type["StateKeyFilter"]) -> "StateKeyFilter": + return cls(set()) + + @classmethod + def only(cls: Type["StateKeyFilter"], state_key: str) -> "StateKeyFilter": + return cls({state_key}) + + def __contains__(self, state_key: object) -> bool: + return not self.options or state_key in self.options + + def add(self, state_key: Optional[str]) -> None: + if state_key is None: + self.options = set() + elif self.options: + self.options.add(state_key) + + class ApiConfig(Config): section = "api" + room_prejoin_state: Mapping[str, StateKeyFilter] + track_puppetted_users_ips: bool + def read_config(self, config: JsonDict, **kwargs: Any) -> None: validate_config(_MAIN_SCHEMA, config, ()) - self.room_prejoin_state = list(self._get_prejoin_state_types(config)) + self.room_prejoin_state = self._build_prejoin_state(config) self.track_puppeted_user_ips = config.get("track_puppeted_user_ips", False) - def _get_prejoin_state_types(self, config: JsonDict) -> Iterable[str]: - """Get the event types to include in the prejoin state - - Parses the config and returns an iterable of the event types to be included. 
- """ + def _build_prejoin_state(self, config: JsonDict) -> Dict[str, StateKeyFilter]: + prejoin_events = {} + for event_type, state_key in self._get_prejoin_state_entries(config): + if event_type not in prejoin_events: + if state_key is None: + filter = StateKeyFilter.any() + else: + filter = StateKeyFilter.only(state_key) + prejoin_events[event_type] = filter + else: + prejoin_events[event_type].add(state_key) + return prejoin_events + + def _get_prejoin_state_entries( + self, config: JsonDict + ) -> Iterable[Tuple[str, Optional[str]]]: + """Get the event types and state keys to include in the prejoin state.""" room_prejoin_state_config = config.get("room_prejoin_state") or {} # backwards-compatibility support for room_invite_state_types @@ -50,33 +96,39 @@ def _get_prejoin_state_types(self, config: JsonDict) -> Iterable[str]: logger.warning(_ROOM_INVITE_STATE_TYPES_WARNING) - yield from config["room_invite_state_types"] + for event_type in config["room_invite_state_types"]: + yield event_type, None return if not room_prejoin_state_config.get("disable_default_event_types"): - yield from _DEFAULT_PREJOIN_STATE_TYPES + yield from _DEFAULT_PREJOIN_STATE_TYPES_AND_STATE_KEYS - yield from room_prejoin_state_config.get("additional_event_types", []) + for entry in room_prejoin_state_config.get("additional_event_types", []): + if isinstance(entry, str): + yield entry, None + else: + yield entry _ROOM_INVITE_STATE_TYPES_WARNING = """\ WARNING: The 'room_invite_state_types' configuration setting is now deprecated, and replaced with 'room_prejoin_state'. New features may not work correctly -unless 'room_invite_state_types' is removed. See the sample configuration file for -details of 'room_prejoin_state'. +unless 'room_invite_state_types' is removed. See the config documentation at + https://matrix-org.github.io/synapse/latest/usage/configuration/config_documentation.html#room_prejoin_state +for details of 'room_prejoin_state'. 
-------------------------------------------------------------------------------- """ -_DEFAULT_PREJOIN_STATE_TYPES = [ - EventTypes.JoinRules, - EventTypes.CanonicalAlias, - EventTypes.RoomAvatar, - EventTypes.RoomEncryption, - EventTypes.Name, +_DEFAULT_PREJOIN_STATE_TYPES_AND_STATE_KEYS = [ + (EventTypes.JoinRules, ""), + (EventTypes.CanonicalAlias, ""), + (EventTypes.RoomAvatar, ""), + (EventTypes.RoomEncryption, ""), + (EventTypes.Name, ""), # Per MSC1772. - EventTypes.Create, + (EventTypes.Create, ""), # Per MSC3173. - EventTypes.Topic, + (EventTypes.Topic, ""), ] @@ -90,7 +142,17 @@ def _get_prejoin_state_types(self, config: JsonDict) -> Iterable[str]: "disable_default_event_types": {"type": "boolean"}, "additional_event_types": { "type": "array", - "items": {"type": "string"}, + "items": { + "oneOf": [ + {"type": "string"}, + { + "type": "array", + "items": {"type": "string"}, + "minItems": 2, + "maxItems": 2, + }, + ], + }, }, }, }, diff --git a/tests/config/test_api.py b/tests/config/test_api.py new file mode 100644 index 000000000000..8c65d2f58be4 --- /dev/null +++ b/tests/config/test_api.py @@ -0,0 +1,152 @@ +from unittest import TestCase as StdlibTestCase + +import yaml + +from synapse.config import ConfigError +from synapse.config.api import ApiConfig, StateKeyFilter + +DEFAULT_PREJOIN_STATE = { + "m.room.join_rules": StateKeyFilter.only(""), + "m.room.canonical_alias": StateKeyFilter.only(""), + "m.room.avatar": StateKeyFilter.only(""), + "m.room.encryption": StateKeyFilter.only(""), + "m.room.name": StateKeyFilter.only(""), + "m.room.create": StateKeyFilter.only(""), + "m.room.topic": StateKeyFilter.only(""), +} + + +class TestRoomPrejoinState(StdlibTestCase): + def test_state_key_filter(self) -> None: + """Sanity check the StateKeyFilter class.""" + s = StateKeyFilter.only("foo") + self.assertIn("foo", s) + self.assertNotIn("bar", s) + self.assertNotIn("baz", s) + s.add("bar") + self.assertIn("foo", s) + self.assertIn("bar", s) + 
self.assertNotIn("baz", s) + + s = StateKeyFilter.any() + self.assertIn("foo", s) + self.assertIn("bar", s) + self.assertIn("baz", s) + s.add("bar") + self.assertIn("foo", s) + self.assertIn("bar", s) + self.assertIn("baz", s) + + def read_config(self, source: str) -> ApiConfig: + config = ApiConfig() + config.read_config(yaml.safe_load(source)) + return config + + def test_no_prejoin_state(self) -> None: + config = self.read_config("foo: bar") + self.assertEqual(config.room_prejoin_state, DEFAULT_PREJOIN_STATE) + + def test_disable_default_event_types(self) -> None: + config = self.read_config( + """ +room_prejoin_state: + disable_default_event_types: true + """ + ) + self.assertEqual(config.room_prejoin_state, {}) + + def test_event_without_state_key(self) -> None: + config = self.read_config( + """ +room_prejoin_state: + disable_default_event_types: true + additional_event_types: + - foo + """ + ) + self.assertEqual(config.room_prejoin_state, {"foo": StateKeyFilter.any()}) + + def test_event_with_specific_state_key(self) -> None: + config = self.read_config( + """ +room_prejoin_state: + disable_default_event_types: true + additional_event_types: + - [foo, bar] + """ + ) + self.assertEqual(config.room_prejoin_state, {"foo": StateKeyFilter.only("bar")}) + + def test_repeated_event_with_specific_state_key(self) -> None: + config = self.read_config( + """ +room_prejoin_state: + disable_default_event_types: true + additional_event_types: + - [foo, bar] + - [foo, baz] + """ + ) + self.assertEqual( + config.room_prejoin_state, {"foo": StateKeyFilter({"bar", "baz"})} + ) + + def test_no_specific_state_key_overrides_specific_state_key(self) -> None: + config = self.read_config( + """ +room_prejoin_state: + disable_default_event_types: true + additional_event_types: + - [foo, bar] + - foo + """ + ) + self.assertEqual(config.room_prejoin_state, {"foo": StateKeyFilter.any()}) + + config = self.read_config( + """ +room_prejoin_state: + disable_default_event_types: true + 
additional_event_types: + - foo + - [foo, bar] + """ + ) + self.assertEqual(config.room_prejoin_state, {"foo": StateKeyFilter.any()}) + + def test_bad_event_type_entry_raises(self) -> None: + with self.assertRaises(ConfigError): + self.read_config( + """ +room_prejoin_state: + additional_event_types: + - [] + """ + ) + + with self.assertRaises(ConfigError): + self.read_config( + """ +room_prejoin_state: + additional_event_types: + - [a] + """ + ) + + with self.assertRaises(ConfigError): + self.read_config( + """ +room_prejoin_state: + additional_event_types: + - [a, b, c] + """ + ) + + with self.assertRaises(ConfigError): + self.read_config( + """ +room_prejoin_state: + additional_event_types: + - [true, 1.23] + """ + ) From 8ed2d936ac457db110f3c40b9f3e37dd0638f850 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Wed, 7 Dec 2022 20:06:30 +0000 Subject: [PATCH 03/13] Read new config --- .../storage/databases/main/events_worker.py | 23 ++++++++++++++----- 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 01e935edef53..7f6ecfef127c 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -20,10 +20,10 @@ TYPE_CHECKING, Any, Collection, - Container, Dict, Iterable, List, + Mapping, MutableMapping, Optional, Set, @@ -46,6 +46,7 @@ RoomVersion, RoomVersions, ) +from synapse.config.api import StateKeyFilter from synapse.events import EventBase, make_event_from_dict from synapse.events.snapshot import EventContext from synapse.events.utils import prune_event @@ -879,7 +880,7 @@ def _get_events_from_local_cache( async def get_stripped_room_state_from_event_context( self, context: EventContext, - state_types_to_include: Container[str], + state_keys_to_include: Mapping[str, StateKeyFilter], membership_user_id: Optional[str] = None, ) -> List[JsonDict]: """ @@ -892,7 +893,7 @@ async def 
get_stripped_room_state_from_event_context( Args: context: The event context to retrieve state of the room from. - state_types_to_include: The type of state events to include. + state_keys_to_include: The state events to include, for each event type. membership_user_id: An optional user ID to include the stripped membership state events of. This is useful when generating the stripped state of a room for invites. We want to send membership events of the inviter, so that the @@ -907,12 +908,22 @@ async def get_stripped_room_state_from_event_context( # non-None. assert current_state_ids is not None + def should_include(t: str, s: str) -> bool: + if t in state_keys_to_include and s in state_keys_to_include[t]: + return True + if ( + membership_user_id + and t == EventTypes.Member + and s == membership_user_id + ): + return True + return False + # The state to include state_to_include_ids = [ e_id - for k, e_id in current_state_ids.items() - if k[0] in state_types_to_include - or (membership_user_id and k == (EventTypes.Member, membership_user_id)) + for (event_type, state_key), e_id in current_state_ids.items() + if should_include(event_type, state_key) ] state_to_include = await self.get_events(state_to_include_ids) From eb0bbbe1b5c0680435f52b93ac41f785cef73d1a Mon Sep 17 00:00:00 2001 From: David Robertson Date: Wed, 7 Dec 2022 20:20:34 +0000 Subject: [PATCH 04/13] Don't use trial/our TestCase where it's not needed Before: ``` $ time trial tests/events/test_utils.py > /dev/null real 0m2.277s user 0m2.186s sys 0m0.083s ``` After: ``` $ time trial tests/events/test_utils.py > /dev/null real 0m0.566s user 0m0.508s sys 0m0.056s ``` --- tests/events/test_utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py index b1c47efac7c0..f77a789b4bef 100644 --- a/tests/events/test_utils.py +++ b/tests/events/test_utils.py @@ -12,6 +12,8 @@ # See the License for the specific language governing 
permissions and # limitations under the License. +import unittest as stdlib_unittest + from synapse.api.constants import EventContentFields from synapse.api.room_versions import RoomVersions from synapse.events import make_event_from_dict @@ -23,8 +25,6 @@ ) from synapse.util.frozenutils import freeze -from tests import unittest - def MockEvent(**kwargs): if "event_id" not in kwargs: @@ -34,7 +34,7 @@ def MockEvent(**kwargs): return make_event_from_dict(kwargs) -class PruneEventTestCase(unittest.TestCase): +class PruneEventTestCase(stdlib_unittest.TestCase): def run_test(self, evdict, matchdict, **kwargs): """ Asserts that a new event constructed with `evdict` will look like @@ -391,7 +391,7 @@ def test_member(self): ) -class SerializeEventTestCase(unittest.TestCase): +class SerializeEventTestCase(stdlib_unittest.TestCase): def serialize(self, ev, fields): return serialize_event( ev, 1479807801915, config=SerializeEventConfig(only_event_fields=fields) @@ -513,7 +513,7 @@ def test_event_fields_fail_if_fields_not_str(self): ) -class CopyPowerLevelsContentTestCase(unittest.TestCase): +class CopyPowerLevelsContentTestCase(stdlib_unittest.TestCase): def setUp(self) -> None: self.test_content = { "ban": 50, From c6199a3f21b4139506b8594635256e54a532cccc Mon Sep 17 00:00:00 2001 From: David Robertson Date: Wed, 7 Dec 2022 20:22:54 +0000 Subject: [PATCH 05/13] Helper to upsert to event fields without exceeding size limits. 
--- synapse/events/utils.py | 32 +++++++++++++++++++++++++++++++- tests/events/test_utils.py | 25 +++++++++++++++++++++++++ 2 files changed, 56 insertions(+), 1 deletion(-) diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 71853caad8c7..13fa93afb87b 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -28,8 +28,14 @@ ) import attr +from canonicaljson import encode_canonical_json -from synapse.api.constants import EventContentFields, EventTypes, RelationTypes +from synapse.api.constants import ( + MAX_PDU_SIZE, + EventContentFields, + EventTypes, + RelationTypes, +) from synapse.api.errors import Codes, SynapseError from synapse.api.room_versions import RoomVersion from synapse.types import JsonDict @@ -674,3 +680,27 @@ def validate_canonicaljson(value: Any) -> None: elif not isinstance(value, (bool, str)) and value is not None: # Other potential JSON values (bool, None, str) are safe. raise SynapseError(400, "Unknown JSON value", Codes.BAD_JSON) + + +def maybe_upsert_event_field( + event: EventBase, container: JsonDict, key: str, value: object +) -> bool: + """Upsert an event field, but only if this doesn't make the event too large. + + Returns true iff the upsert took place. + """ + if key in container: + old_value: object = container[key] + container[key] = value + # NB: here and below, we assume that passing a non-None `time_now` argument to + # get_pdu_json doesn't increase the size of the encoded result. 
+ upsert_okay = len(encode_canonical_json(event.get_pdu_json())) <= MAX_PDU_SIZE + if not upsert_okay: + container[key] = old_value + else: + container[key] = value + upsert_okay = len(encode_canonical_json(event.get_pdu_json())) <= MAX_PDU_SIZE + if not upsert_okay: + del container[key] + + return upsert_okay diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py index f77a789b4bef..a79256846fb3 100644 --- a/tests/events/test_utils.py +++ b/tests/events/test_utils.py @@ -20,6 +20,7 @@ from synapse.events.utils import ( SerializeEventConfig, copy_and_fixup_power_levels_contents, + maybe_upsert_event_field, prune_event, serialize_event, ) @@ -34,6 +35,30 @@ def MockEvent(**kwargs): return make_event_from_dict(kwargs) +class TestMaybeUpsertEventField(stdlib_unittest.TestCase): + def test_update_okay(self) -> None: + event = make_event_from_dict({"event_id": "$1234"}) + success = maybe_upsert_event_field(event, event.unsigned, "key", "value") + self.assertTrue(success) + self.assertEqual(event.unsigned["key"], "value") + + def test_update_not_okay(self) -> None: + event = make_event_from_dict({"event_id": "$1234"}) + LARGE_STRING = "a" * 100_000 + success = maybe_upsert_event_field(event, event.unsigned, "key", LARGE_STRING) + self.assertFalse(success) + self.assertNotIn("key", event.unsigned) + + def test_update_not_okay_leaves_original_value(self) -> None: + event = make_event_from_dict( + {"event_id": "$1234", "unsigned": {"key": "value"}} + ) + LARGE_STRING = "a" * 100_000 + success = maybe_upsert_event_field(event, event.unsigned, "key", LARGE_STRING) + self.assertFalse(success) + self.assertEqual(event.unsigned["key"], "value") + + class PruneEventTestCase(stdlib_unittest.TestCase): def run_test(self, evdict, matchdict, **kwargs): """ From ddc563c18e682cbea1d442e695e2babb37ae471c Mon Sep 17 00:00:00 2001 From: David Robertson Date: Wed, 7 Dec 2022 20:23:42 +0000 Subject: [PATCH 06/13] Use helper when adding invite/knock state Now that we allow 
admins to include events in prejoin room state with arbitrary state keys, be a good Matrix citizen and ensure they don't accidentally create an oversized event. --- synapse/handlers/message.py | 29 ++++++++++++++++++----------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 5cbe89f4fddf..8a6c4ca7904b 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -50,6 +50,7 @@ from synapse.events import EventBase, relation_from_event from synapse.events.builder import EventBuilder from synapse.events.snapshot import EventContext +from synapse.events.utils import maybe_upsert_event_field from synapse.events.validator import EventValidator from synapse.handlers.directory import DirectoryHandler from synapse.logging import opentracing @@ -1739,12 +1740,15 @@ async def persist_and_notify_client_events( if event.type == EventTypes.Member: if event.content["membership"] == Membership.INVITE: - event.unsigned[ - "invite_room_state" - ] = await self.store.get_stripped_room_state_from_event_context( - context, - self.room_prejoin_state_types, - membership_user_id=event.sender, + maybe_upsert_event_field( + event, + event.unsigned, + "invite_room_state", + await self.store.get_stripped_room_state_from_event_context( + context, + self.room_prejoin_state_types, + membership_user_id=event.sender, + ), ) invitee = UserID.from_string(event.state_key) @@ -1762,11 +1766,14 @@ async def persist_and_notify_client_events( event.signatures.update(returned_invite.signatures) if event.content["membership"] == Membership.KNOCK: - event.unsigned[ - "knock_room_state" - ] = await self.store.get_stripped_room_state_from_event_context( - context, - self.room_prejoin_state_types, + maybe_upsert_event_field( + event, + event.unsigned, + "knock_room_state", + await self.store.get_stripped_room_state_from_event_context( + context, + self.room_prejoin_state_types, + ), ) if event.type == 
EventTypes.Redaction: From 33a559f80f416525d20e0ea9ff72f282848cceed Mon Sep 17 00:00:00 2001 From: David Robertson Date: Wed, 7 Dec 2022 20:32:40 +0000 Subject: [PATCH 07/13] Changelog --- changelog.d/14642.feature | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/14642.feature diff --git a/changelog.d/14642.feature b/changelog.d/14642.feature new file mode 100644 index 000000000000..cbc9db10c309 --- /dev/null +++ b/changelog.d/14642.feature @@ -0,0 +1 @@ +Allow selecting "prejoin" events by state keys in addition to event types. From fe30c5d1a457e4e6ed10a513f3b99209ddf804f3 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Mon, 12 Dec 2022 16:37:18 +0000 Subject: [PATCH 08/13] Move StateFilter tests should have done this in #14668 --- tests/storage/test_state.py | 623 +---------------------------------- tests/types/__init__.py | 0 tests/types/test_state.py | 627 ++++++++++++++++++++++++++++++++++++ 3 files changed, 628 insertions(+), 622 deletions(-) create mode 100644 tests/types/__init__.py create mode 100644 tests/types/test_state.py diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index a433e7087063..bad7f0bc6026 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -26,7 +26,7 @@ from synapse.types.state import StateFilter from synapse.util import Clock -from tests.unittest import HomeserverTestCase, TestCase +from tests.unittest import HomeserverTestCase logger = logging.getLogger(__name__) @@ -494,624 +494,3 @@ def test_get_state_for_event(self) -> None: self.assertEqual(is_all, True) self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict) - - -class StateFilterDifferenceTestCase(TestCase): - def assert_difference( - self, minuend: StateFilter, subtrahend: StateFilter, expected: StateFilter - ) -> None: - self.assertEqual( - minuend.approx_difference(subtrahend), - expected, - f"StateFilter difference not 
correct:\n\n\t{minuend!r}\nminus\n\t{subtrahend!r}\nwas\n\t{minuend.approx_difference(subtrahend)}\nexpected\n\t{expected}", - ) - - def test_state_filter_difference_no_include_other_minus_no_include_other( - self, - ) -> None: - """ - Tests the StateFilter.approx_difference method - where, in a.approx_difference(b), both a and b do not have the - include_others flag set. - """ - # (wildcard on state keys) - (wildcard on state keys): - self.assert_difference( - StateFilter.freeze( - {EventTypes.Member: None, EventTypes.Create: None}, - include_others=False, - ), - StateFilter.freeze( - {EventTypes.Member: None, EventTypes.CanonicalAlias: None}, - include_others=False, - ), - StateFilter.freeze({EventTypes.Create: None}, include_others=False), - ) - - # (wildcard on state keys) - (specific state keys) - # This one is an over-approximation because we can't represent - # 'all state keys except a few named examples' - self.assert_difference( - StateFilter.freeze({EventTypes.Member: None}, include_others=False), - StateFilter.freeze( - {EventTypes.Member: {"@wombat:spqr"}}, - include_others=False, - ), - StateFilter.freeze({EventTypes.Member: None}, include_others=False), - ) - - # (wildcard on state keys) - (no state keys) - self.assert_difference( - StateFilter.freeze( - {EventTypes.Member: None}, - include_others=False, - ), - StateFilter.freeze( - { - EventTypes.Member: set(), - }, - include_others=False, - ), - StateFilter.freeze( - {EventTypes.Member: None}, - include_others=False, - ), - ) - - # (specific state keys) - (wildcard on state keys): - self.assert_difference( - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, - EventTypes.CanonicalAlias: {""}, - }, - include_others=False, - ), - StateFilter.freeze( - {EventTypes.Member: None}, - include_others=False, - ), - StateFilter.freeze( - {EventTypes.CanonicalAlias: {""}}, - include_others=False, - ), - ) - - # (specific state keys) - (specific state keys) - self.assert_difference( - 
StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, - EventTypes.CanonicalAlias: {""}, - }, - include_others=False, - ), - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr"}, - }, - include_others=False, - ), - StateFilter.freeze( - { - EventTypes.Member: {"@spqr:spqr"}, - EventTypes.CanonicalAlias: {""}, - }, - include_others=False, - ), - ) - - # (specific state keys) - (no state keys) - self.assert_difference( - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, - EventTypes.CanonicalAlias: {""}, - }, - include_others=False, - ), - StateFilter.freeze( - { - EventTypes.Member: set(), - }, - include_others=False, - ), - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, - EventTypes.CanonicalAlias: {""}, - }, - include_others=False, - ), - ) - - def test_state_filter_difference_include_other_minus_no_include_other(self) -> None: - """ - Tests the StateFilter.approx_difference method - where, in a.approx_difference(b), only a has the include_others flag set. - """ - # (wildcard on state keys) - (wildcard on state keys): - self.assert_difference( - StateFilter.freeze( - {EventTypes.Member: None, EventTypes.Create: None}, - include_others=True, - ), - StateFilter.freeze( - {EventTypes.Member: None, EventTypes.CanonicalAlias: None}, - include_others=False, - ), - StateFilter.freeze( - { - EventTypes.Create: None, - EventTypes.Member: set(), - EventTypes.CanonicalAlias: set(), - }, - include_others=True, - ), - ) - - # (wildcard on state keys) - (specific state keys) - # This one is an over-approximation because we can't represent - # 'all state keys except a few named examples' - # This also shows that the resultant state filter is normalised. 
- self.assert_difference( - StateFilter.freeze({EventTypes.Member: None}, include_others=True), - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr"}, - EventTypes.Create: {""}, - }, - include_others=False, - ), - StateFilter(types=frozendict(), include_others=True), - ) - - # (wildcard on state keys) - (no state keys) - self.assert_difference( - StateFilter.freeze( - {EventTypes.Member: None}, - include_others=True, - ), - StateFilter.freeze( - { - EventTypes.Member: set(), - }, - include_others=False, - ), - StateFilter( - types=frozendict(), - include_others=True, - ), - ) - - # (specific state keys) - (wildcard on state keys): - self.assert_difference( - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, - EventTypes.CanonicalAlias: {""}, - }, - include_others=True, - ), - StateFilter.freeze( - {EventTypes.Member: None}, - include_others=False, - ), - StateFilter.freeze( - { - EventTypes.CanonicalAlias: {""}, - EventTypes.Member: set(), - }, - include_others=True, - ), - ) - - # (specific state keys) - (specific state keys) - self.assert_difference( - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, - EventTypes.CanonicalAlias: {""}, - }, - include_others=True, - ), - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr"}, - }, - include_others=False, - ), - StateFilter.freeze( - { - EventTypes.Member: {"@spqr:spqr"}, - EventTypes.CanonicalAlias: {""}, - }, - include_others=True, - ), - ) - - # (specific state keys) - (no state keys) - self.assert_difference( - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, - EventTypes.CanonicalAlias: {""}, - }, - include_others=True, - ), - StateFilter.freeze( - { - EventTypes.Member: set(), - }, - include_others=False, - ), - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, - EventTypes.CanonicalAlias: {""}, - }, - include_others=True, - ), - ) - - def 
test_state_filter_difference_include_other_minus_include_other(self) -> None: - """ - Tests the StateFilter.approx_difference method - where, in a.approx_difference(b), both a and b have the include_others - flag set. - """ - # (wildcard on state keys) - (wildcard on state keys): - self.assert_difference( - StateFilter.freeze( - {EventTypes.Member: None, EventTypes.Create: None}, - include_others=True, - ), - StateFilter.freeze( - {EventTypes.Member: None, EventTypes.CanonicalAlias: None}, - include_others=True, - ), - StateFilter(types=frozendict(), include_others=False), - ) - - # (wildcard on state keys) - (specific state keys) - # This one is an over-approximation because we can't represent - # 'all state keys except a few named examples' - self.assert_difference( - StateFilter.freeze({EventTypes.Member: None}, include_others=True), - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr"}, - EventTypes.CanonicalAlias: {""}, - }, - include_others=True, - ), - StateFilter.freeze( - {EventTypes.Member: None, EventTypes.CanonicalAlias: None}, - include_others=False, - ), - ) - - # (wildcard on state keys) - (no state keys) - self.assert_difference( - StateFilter.freeze( - {EventTypes.Member: None}, - include_others=True, - ), - StateFilter.freeze( - { - EventTypes.Member: set(), - }, - include_others=True, - ), - StateFilter.freeze( - {EventTypes.Member: None}, - include_others=False, - ), - ) - - # (specific state keys) - (wildcard on state keys): - self.assert_difference( - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, - EventTypes.CanonicalAlias: {""}, - }, - include_others=True, - ), - StateFilter.freeze( - {EventTypes.Member: None}, - include_others=True, - ), - StateFilter( - types=frozendict(), - include_others=False, - ), - ) - - # (specific state keys) - (specific state keys) - # This one is an over-approximation because we can't represent - # 'all state keys except a few named examples' - self.assert_difference( - 
StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, - EventTypes.CanonicalAlias: {""}, - EventTypes.Create: {""}, - }, - include_others=True, - ), - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr"}, - EventTypes.Create: set(), - }, - include_others=True, - ), - StateFilter.freeze( - { - EventTypes.Member: {"@spqr:spqr"}, - EventTypes.Create: {""}, - }, - include_others=False, - ), - ) - - # (specific state keys) - (no state keys) - self.assert_difference( - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, - EventTypes.CanonicalAlias: {""}, - }, - include_others=True, - ), - StateFilter.freeze( - { - EventTypes.Member: set(), - }, - include_others=True, - ), - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, - }, - include_others=False, - ), - ) - - def test_state_filter_difference_no_include_other_minus_include_other(self) -> None: - """ - Tests the StateFilter.approx_difference method - where, in a.approx_difference(b), only b has the include_others flag set. 
- """ - # (wildcard on state keys) - (wildcard on state keys): - self.assert_difference( - StateFilter.freeze( - {EventTypes.Member: None, EventTypes.Create: None}, - include_others=False, - ), - StateFilter.freeze( - {EventTypes.Member: None, EventTypes.CanonicalAlias: None}, - include_others=True, - ), - StateFilter(types=frozendict(), include_others=False), - ) - - # (wildcard on state keys) - (specific state keys) - # This one is an over-approximation because we can't represent - # 'all state keys except a few named examples' - self.assert_difference( - StateFilter.freeze({EventTypes.Member: None}, include_others=False), - StateFilter.freeze( - {EventTypes.Member: {"@wombat:spqr"}}, - include_others=True, - ), - StateFilter.freeze({EventTypes.Member: None}, include_others=False), - ) - - # (wildcard on state keys) - (no state keys) - self.assert_difference( - StateFilter.freeze( - {EventTypes.Member: None}, - include_others=False, - ), - StateFilter.freeze( - { - EventTypes.Member: set(), - }, - include_others=True, - ), - StateFilter.freeze( - {EventTypes.Member: None}, - include_others=False, - ), - ) - - # (specific state keys) - (wildcard on state keys): - self.assert_difference( - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, - EventTypes.CanonicalAlias: {""}, - }, - include_others=False, - ), - StateFilter.freeze( - {EventTypes.Member: None}, - include_others=True, - ), - StateFilter( - types=frozendict(), - include_others=False, - ), - ) - - # (specific state keys) - (specific state keys) - # This one is an over-approximation because we can't represent - # 'all state keys except a few named examples' - self.assert_difference( - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, - EventTypes.CanonicalAlias: {""}, - }, - include_others=False, - ), - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr"}, - }, - include_others=True, - ), - StateFilter.freeze( - { - EventTypes.Member: 
{"@spqr:spqr"}, - }, - include_others=False, - ), - ) - - # (specific state keys) - (no state keys) - self.assert_difference( - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, - EventTypes.CanonicalAlias: {""}, - }, - include_others=False, - ), - StateFilter.freeze( - { - EventTypes.Member: set(), - }, - include_others=True, - ), - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, - }, - include_others=False, - ), - ) - - def test_state_filter_difference_simple_cases(self) -> None: - """ - Tests some very simple cases of the StateFilter approx_difference, - that are not explicitly tested by the more in-depth tests. - """ - - self.assert_difference(StateFilter.all(), StateFilter.all(), StateFilter.none()) - - self.assert_difference( - StateFilter.all(), - StateFilter.none(), - StateFilter.all(), - ) - - -class StateFilterTestCase(TestCase): - def test_return_expanded(self) -> None: - """ - Tests the behaviour of the return_expanded() function that expands - StateFilters to include more state types (for the sake of cache hit rate). 
- """ - - self.assertEqual(StateFilter.all().return_expanded(), StateFilter.all()) - - self.assertEqual(StateFilter.none().return_expanded(), StateFilter.none()) - - # Concrete-only state filters stay the same - # (Case: mixed filter) - self.assertEqual( - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:test", "@alicia:test"}, - "some.other.state.type": {""}, - }, - include_others=False, - ).return_expanded(), - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:test", "@alicia:test"}, - "some.other.state.type": {""}, - }, - include_others=False, - ), - ) - - # Concrete-only state filters stay the same - # (Case: non-member-only filter) - self.assertEqual( - StateFilter.freeze( - {"some.other.state.type": {""}}, include_others=False - ).return_expanded(), - StateFilter.freeze({"some.other.state.type": {""}}, include_others=False), - ) - - # Concrete-only state filters stay the same - # (Case: member-only filter) - self.assertEqual( - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:test", "@alicia:test"}, - }, - include_others=False, - ).return_expanded(), - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:test", "@alicia:test"}, - }, - include_others=False, - ), - ) - - # Wildcard member-only state filters stay the same - self.assertEqual( - StateFilter.freeze( - {EventTypes.Member: None}, - include_others=False, - ).return_expanded(), - StateFilter.freeze( - {EventTypes.Member: None}, - include_others=False, - ), - ) - - # If there is a wildcard in the non-member portion of the filter, - # it's expanded to include ALL non-member events. 
- # (Case: mixed filter) - self.assertEqual( - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:test", "@alicia:test"}, - "some.other.state.type": None, - }, - include_others=False, - ).return_expanded(), - StateFilter.freeze( - {EventTypes.Member: {"@wombat:test", "@alicia:test"}}, - include_others=True, - ), - ) - - # If there is a wildcard in the non-member portion of the filter, - # it's expanded to include ALL non-member events. - # (Case: non-member-only filter) - self.assertEqual( - StateFilter.freeze( - { - "some.other.state.type": None, - }, - include_others=False, - ).return_expanded(), - StateFilter.freeze({EventTypes.Member: set()}, include_others=True), - ) - self.assertEqual( - StateFilter.freeze( - { - "some.other.state.type": None, - "yet.another.state.type": {"wombat"}, - }, - include_others=False, - ).return_expanded(), - StateFilter.freeze({EventTypes.Member: set()}, include_others=True), - ) diff --git a/tests/types/__init__.py b/tests/types/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/types/test_state.py b/tests/types/test_state.py new file mode 100644 index 000000000000..eb809f9fb739 --- /dev/null +++ b/tests/types/test_state.py @@ -0,0 +1,627 @@ +from frozendict import frozendict + +from synapse.api.constants import EventTypes +from synapse.types.state import StateFilter + +from tests.unittest import TestCase + + +class StateFilterDifferenceTestCase(TestCase): + def assert_difference( + self, minuend: StateFilter, subtrahend: StateFilter, expected: StateFilter + ) -> None: + self.assertEqual( + minuend.approx_difference(subtrahend), + expected, + f"StateFilter difference not correct:\n\n\t{minuend!r}\nminus\n\t{subtrahend!r}\nwas\n\t{minuend.approx_difference(subtrahend)}\nexpected\n\t{expected}", + ) + + def test_state_filter_difference_no_include_other_minus_no_include_other( + self, + ) -> None: + """ + Tests the StateFilter.approx_difference method + where, in a.approx_difference(b), both a 
and b do not have the + include_others flag set. + """ + # (wildcard on state keys) - (wildcard on state keys): + self.assert_difference( + StateFilter.freeze( + {EventTypes.Member: None, EventTypes.Create: None}, + include_others=False, + ), + StateFilter.freeze( + {EventTypes.Member: None, EventTypes.CanonicalAlias: None}, + include_others=False, + ), + StateFilter.freeze({EventTypes.Create: None}, include_others=False), + ) + + # (wildcard on state keys) - (specific state keys) + # This one is an over-approximation because we can't represent + # 'all state keys except a few named examples' + self.assert_difference( + StateFilter.freeze({EventTypes.Member: None}, include_others=False), + StateFilter.freeze( + {EventTypes.Member: {"@wombat:spqr"}}, + include_others=False, + ), + StateFilter.freeze({EventTypes.Member: None}, include_others=False), + ) + + # (wildcard on state keys) - (no state keys) + self.assert_difference( + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=False, + ), + StateFilter.freeze( + { + EventTypes.Member: set(), + }, + include_others=False, + ), + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=False, + ), + ) + + # (specific state keys) - (wildcard on state keys): + self.assert_difference( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=False, + ), + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=False, + ), + StateFilter.freeze( + {EventTypes.CanonicalAlias: {""}}, + include_others=False, + ), + ) + + # (specific state keys) - (specific state keys) + self.assert_difference( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=False, + ), + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr"}, + }, + include_others=False, + ), + StateFilter.freeze( + { + EventTypes.Member: {"@spqr:spqr"}, + 
EventTypes.CanonicalAlias: {""}, + }, + include_others=False, + ), + ) + + # (specific state keys) - (no state keys) + self.assert_difference( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=False, + ), + StateFilter.freeze( + { + EventTypes.Member: set(), + }, + include_others=False, + ), + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=False, + ), + ) + + def test_state_filter_difference_include_other_minus_no_include_other(self) -> None: + """ + Tests the StateFilter.approx_difference method + where, in a.approx_difference(b), only a has the include_others flag set. + """ + # (wildcard on state keys) - (wildcard on state keys): + self.assert_difference( + StateFilter.freeze( + {EventTypes.Member: None, EventTypes.Create: None}, + include_others=True, + ), + StateFilter.freeze( + {EventTypes.Member: None, EventTypes.CanonicalAlias: None}, + include_others=False, + ), + StateFilter.freeze( + { + EventTypes.Create: None, + EventTypes.Member: set(), + EventTypes.CanonicalAlias: set(), + }, + include_others=True, + ), + ) + + # (wildcard on state keys) - (specific state keys) + # This one is an over-approximation because we can't represent + # 'all state keys except a few named examples' + # This also shows that the resultant state filter is normalised. 
+ self.assert_difference( + StateFilter.freeze({EventTypes.Member: None}, include_others=True), + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr"}, + EventTypes.Create: {""}, + }, + include_others=False, + ), + StateFilter(types=frozendict(), include_others=True), + ) + + # (wildcard on state keys) - (no state keys) + self.assert_difference( + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=True, + ), + StateFilter.freeze( + { + EventTypes.Member: set(), + }, + include_others=False, + ), + StateFilter( + types=frozendict(), + include_others=True, + ), + ) + + # (specific state keys) - (wildcard on state keys): + self.assert_difference( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=True, + ), + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=False, + ), + StateFilter.freeze( + { + EventTypes.CanonicalAlias: {""}, + EventTypes.Member: set(), + }, + include_others=True, + ), + ) + + # (specific state keys) - (specific state keys) + self.assert_difference( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=True, + ), + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr"}, + }, + include_others=False, + ), + StateFilter.freeze( + { + EventTypes.Member: {"@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=True, + ), + ) + + # (specific state keys) - (no state keys) + self.assert_difference( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=True, + ), + StateFilter.freeze( + { + EventTypes.Member: set(), + }, + include_others=False, + ), + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=True, + ), + ) + + def 
test_state_filter_difference_include_other_minus_include_other(self) -> None: + """ + Tests the StateFilter.approx_difference method + where, in a.approx_difference(b), both a and b have the include_others + flag set. + """ + # (wildcard on state keys) - (wildcard on state keys): + self.assert_difference( + StateFilter.freeze( + {EventTypes.Member: None, EventTypes.Create: None}, + include_others=True, + ), + StateFilter.freeze( + {EventTypes.Member: None, EventTypes.CanonicalAlias: None}, + include_others=True, + ), + StateFilter(types=frozendict(), include_others=False), + ) + + # (wildcard on state keys) - (specific state keys) + # This one is an over-approximation because we can't represent + # 'all state keys except a few named examples' + self.assert_difference( + StateFilter.freeze({EventTypes.Member: None}, include_others=True), + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=True, + ), + StateFilter.freeze( + {EventTypes.Member: None, EventTypes.CanonicalAlias: None}, + include_others=False, + ), + ) + + # (wildcard on state keys) - (no state keys) + self.assert_difference( + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=True, + ), + StateFilter.freeze( + { + EventTypes.Member: set(), + }, + include_others=True, + ), + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=False, + ), + ) + + # (specific state keys) - (wildcard on state keys): + self.assert_difference( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=True, + ), + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=True, + ), + StateFilter( + types=frozendict(), + include_others=False, + ), + ) + + # (specific state keys) - (specific state keys) + # This one is an over-approximation because we can't represent + # 'all state keys except a few named examples' + self.assert_difference( + 
StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + EventTypes.Create: {""}, + }, + include_others=True, + ), + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr"}, + EventTypes.Create: set(), + }, + include_others=True, + ), + StateFilter.freeze( + { + EventTypes.Member: {"@spqr:spqr"}, + EventTypes.Create: {""}, + }, + include_others=False, + ), + ) + + # (specific state keys) - (no state keys) + self.assert_difference( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=True, + ), + StateFilter.freeze( + { + EventTypes.Member: set(), + }, + include_others=True, + ), + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + }, + include_others=False, + ), + ) + + def test_state_filter_difference_no_include_other_minus_include_other(self) -> None: + """ + Tests the StateFilter.approx_difference method + where, in a.approx_difference(b), only b has the include_others flag set. 
+ """ + # (wildcard on state keys) - (wildcard on state keys): + self.assert_difference( + StateFilter.freeze( + {EventTypes.Member: None, EventTypes.Create: None}, + include_others=False, + ), + StateFilter.freeze( + {EventTypes.Member: None, EventTypes.CanonicalAlias: None}, + include_others=True, + ), + StateFilter(types=frozendict(), include_others=False), + ) + + # (wildcard on state keys) - (specific state keys) + # This one is an over-approximation because we can't represent + # 'all state keys except a few named examples' + self.assert_difference( + StateFilter.freeze({EventTypes.Member: None}, include_others=False), + StateFilter.freeze( + {EventTypes.Member: {"@wombat:spqr"}}, + include_others=True, + ), + StateFilter.freeze({EventTypes.Member: None}, include_others=False), + ) + + # (wildcard on state keys) - (no state keys) + self.assert_difference( + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=False, + ), + StateFilter.freeze( + { + EventTypes.Member: set(), + }, + include_others=True, + ), + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=False, + ), + ) + + # (specific state keys) - (wildcard on state keys): + self.assert_difference( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=False, + ), + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=True, + ), + StateFilter( + types=frozendict(), + include_others=False, + ), + ) + + # (specific state keys) - (specific state keys) + # This one is an over-approximation because we can't represent + # 'all state keys except a few named examples' + self.assert_difference( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=False, + ), + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr"}, + }, + include_others=True, + ), + StateFilter.freeze( + { + EventTypes.Member: 
{"@spqr:spqr"}, + }, + include_others=False, + ), + ) + + # (specific state keys) - (no state keys) + self.assert_difference( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=False, + ), + StateFilter.freeze( + { + EventTypes.Member: set(), + }, + include_others=True, + ), + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + }, + include_others=False, + ), + ) + + def test_state_filter_difference_simple_cases(self) -> None: + """ + Tests some very simple cases of the StateFilter approx_difference, + that are not explicitly tested by the more in-depth tests. + """ + + self.assert_difference(StateFilter.all(), StateFilter.all(), StateFilter.none()) + + self.assert_difference( + StateFilter.all(), + StateFilter.none(), + StateFilter.all(), + ) + + +class StateFilterTestCase(TestCase): + def test_return_expanded(self) -> None: + """ + Tests the behaviour of the return_expanded() function that expands + StateFilters to include more state types (for the sake of cache hit rate). 
+ """ + + self.assertEqual(StateFilter.all().return_expanded(), StateFilter.all()) + + self.assertEqual(StateFilter.none().return_expanded(), StateFilter.none()) + + # Concrete-only state filters stay the same + # (Case: mixed filter) + self.assertEqual( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:test", "@alicia:test"}, + "some.other.state.type": {""}, + }, + include_others=False, + ).return_expanded(), + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:test", "@alicia:test"}, + "some.other.state.type": {""}, + }, + include_others=False, + ), + ) + + # Concrete-only state filters stay the same + # (Case: non-member-only filter) + self.assertEqual( + StateFilter.freeze( + {"some.other.state.type": {""}}, include_others=False + ).return_expanded(), + StateFilter.freeze({"some.other.state.type": {""}}, include_others=False), + ) + + # Concrete-only state filters stay the same + # (Case: member-only filter) + self.assertEqual( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:test", "@alicia:test"}, + }, + include_others=False, + ).return_expanded(), + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:test", "@alicia:test"}, + }, + include_others=False, + ), + ) + + # Wildcard member-only state filters stay the same + self.assertEqual( + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=False, + ).return_expanded(), + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=False, + ), + ) + + # If there is a wildcard in the non-member portion of the filter, + # it's expanded to include ALL non-member events. 
+ # (Case: mixed filter) + self.assertEqual( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:test", "@alicia:test"}, + "some.other.state.type": None, + }, + include_others=False, + ).return_expanded(), + StateFilter.freeze( + {EventTypes.Member: {"@wombat:test", "@alicia:test"}}, + include_others=True, + ), + ) + + # If there is a wildcard in the non-member portion of the filter, + # it's expanded to include ALL non-member events. + # (Case: non-member-only filter) + self.assertEqual( + StateFilter.freeze( + { + "some.other.state.type": None, + }, + include_others=False, + ).return_expanded(), + StateFilter.freeze({EventTypes.Member: set()}, include_others=True), + ) + self.assertEqual( + StateFilter.freeze( + { + "some.other.state.type": None, + "yet.another.state.type": {"wombat"}, + }, + include_others=False, + ).return_expanded(), + StateFilter.freeze({EventTypes.Member: set()}, include_others=True), + ) From 1bc0f13e7daea12348e147bf8a630d4872a04fc1 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Mon, 12 Dec 2022 16:42:38 +0000 Subject: [PATCH 09/13] Add extra methods to StateFilter --- synapse/types/state.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/synapse/types/state.py b/synapse/types/state.py index 0004d955b434..743a4f9217ee 100644 --- a/synapse/types/state.py +++ b/synapse/types/state.py @@ -118,6 +118,15 @@ def from_types(types: Iterable[Tuple[str, Optional[str]]]) -> "StateFilter": ) ) + def to_types(self) -> Iterable[Tuple[str, Optional[str]]]: + """The inverse to `from_types`.""" + for (event_type, state_keys) in self.types.items(): + if state_keys is None: + yield event_type, None + else: + for state_key in state_keys: + yield event_type, state_key + @staticmethod def from_lazy_load_member_list(members: Iterable[str]) -> "StateFilter": """Creates a filter that returns all non-member events, plus the member @@ -343,6 +352,15 @@ def concrete_types(self) -> List[Tuple[str, str]]: for s in state_keys ] + def 
wildcard_types(self) -> List[str]: + """Returns a list of event types which require us to fetch all state keys. + This will be empty unless `has_wildcards` returns True. + + Returns: + A list of event types. + """ + return [t for t, state_keys in self.types.items() if state_keys is None] + def get_member_split(self) -> Tuple["StateFilter", "StateFilter"]: """Return the filter split into two: one which assumes it's exclusively matching against member state, and one which assumes it's matching From 11124ed8c5aeb45c7a782f3ed41a188d10261f6e Mon Sep 17 00:00:00 2001 From: David Robertson Date: Mon, 12 Dec 2022 16:49:39 +0000 Subject: [PATCH 10/13] Use StateFilter --- synapse/config/api.py | 53 ++------------- .../storage/databases/main/events_worker.py | 38 ++++------- tests/config/test_api.py | 67 +++++++++---------- 3 files changed, 50 insertions(+), 108 deletions(-) diff --git a/synapse/config/api.py b/synapse/config/api.py index a29decd6007e..27d50d118f3f 100644 --- a/synapse/config/api.py +++ b/synapse/config/api.py @@ -13,71 +13,30 @@ # limitations under the License. import logging -from typing import Any, Container, Dict, Iterable, Mapping, Optional, Set, Tuple, Type - -import attr +from typing import Any, Iterable, Optional, Tuple from synapse.api.constants import EventTypes from synapse.config._base import Config, ConfigError from synapse.config._util import validate_config from synapse.types import JsonDict +from synapse.types.state import StateFilter logger = logging.getLogger(__name__) -@attr.s(auto_attribs=True) -class StateKeyFilter(Container[str]): - """A simpler version of StateFilter which ignores event types. - - Represents an optional constraint that state_keys must belong to a given set of - strings called `options`. An empty set of `options` means that there are no - restrictions. 
- """ - - options: Set[str] - - @classmethod - def any(cls: Type["StateKeyFilter"]) -> "StateKeyFilter": - return cls(set()) - - @classmethod - def only(cls: Type["StateKeyFilter"], state_key: str) -> "StateKeyFilter": - return cls({state_key}) - - def __contains__(self, state_key: object) -> bool: - return not self.options or state_key in self.options - - def add(self, state_key: Optional[str]) -> None: - if state_key is None: - self.options = set() - elif self.options: - self.options.add(state_key) - - class ApiConfig(Config): section = "api" - room_prejoin_state: Mapping[str, StateKeyFilter] + room_prejoin_state: StateFilter track_puppetted_users_ips: bool def read_config(self, config: JsonDict, **kwargs: Any) -> None: validate_config(_MAIN_SCHEMA, config, ()) - self.room_prejoin_state = self._build_prejoin_state(config) + self.room_prejoin_state = StateFilter.from_types( + self._get_prejoin_state_entries(config) + ) self.track_puppeted_user_ips = config.get("track_puppeted_user_ips", False) - def _build_prejoin_state(self, config: JsonDict) -> Dict[str, StateKeyFilter]: - prejoin_events = {} - for event_type, state_key in self._get_prejoin_state_entries(config): - if event_type not in prejoin_events: - if state_key is None: - filter = StateKeyFilter.any() - else: - filter = StateKeyFilter.only(state_key) - prejoin_events[event_type] = filter - else: - prejoin_events[event_type].add(state_key) - return prejoin_events - def _get_prejoin_state_entries( self, config: JsonDict ) -> Iterable[Tuple[str, Optional[str]]]: diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 7f6ecfef127c..a36549e7c24f 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -16,6 +16,7 @@ import threading import weakref from enum import Enum, auto +from itertools import chain from typing import ( TYPE_CHECKING, Any, @@ -23,7 +24,6 @@ Dict, Iterable, List, - 
Mapping, MutableMapping, Optional, Set, @@ -46,7 +46,6 @@ RoomVersion, RoomVersions, ) -from synapse.config.api import StateKeyFilter from synapse.events import EventBase, make_event_from_dict from synapse.events.snapshot import EventContext from synapse.events.utils import prune_event @@ -77,6 +76,7 @@ ) from synapse.storage.util.sequence import build_sequence_generator from synapse.types import JsonDict, get_domain_from_id +from synapse.types.state import StateFilter from synapse.util import unwrapFirstError from synapse.util.async_helpers import ObservableDeferred, delay_cancellation from synapse.util.caches.descriptors import cached, cachedList @@ -880,7 +880,7 @@ def _get_events_from_local_cache( async def get_stripped_room_state_from_event_context( self, context: EventContext, - state_keys_to_include: Mapping[str, StateKeyFilter], + state_keys_to_include: StateFilter, membership_user_id: Optional[str] = None, ) -> List[JsonDict]: """ @@ -902,31 +902,21 @@ async def get_stripped_room_state_from_event_context( Returns: A list of dictionaries, each representing a stripped state event from the room. """ - current_state_ids = await context.get_current_state_ids() + if membership_user_id: + types = chain( + state_keys_to_include.to_types(), + [(EventTypes.Member, membership_user_id)], + ) + filter = StateFilter.from_types(types) + else: + filter = state_keys_to_include + selected_state_ids = await context.get_current_state_ids(filter) # We know this event is not an outlier, so this must be # non-None. 
- assert current_state_ids is not None - - def should_include(t: str, s: str) -> bool: - if t in state_keys_to_include and s in state_keys_to_include[t]: - return True - if ( - membership_user_id - and t == EventTypes.Member - and s == membership_user_id - ): - return True - return False - - # The state to include - state_to_include_ids = [ - e_id - for (event_type, state_key), e_id in current_state_ids.items() - if should_include(event_type, state_key) - ] + assert selected_state_ids is not None - state_to_include = await self.get_events(state_to_include_ids) + state_to_include = await self.get_events(selected_state_ids.values()) return [ { diff --git a/tests/config/test_api.py b/tests/config/test_api.py index 8c65d2f58be4..6773c9a2773a 100644 --- a/tests/config/test_api.py +++ b/tests/config/test_api.py @@ -3,40 +3,21 @@ import yaml from synapse.config import ConfigError -from synapse.config.api import ApiConfig, StateKeyFilter - -DEFAULT_PREJOIN_STATE = { - "m.room.join_rules": StateKeyFilter.only(""), - "m.room.canonical_alias": StateKeyFilter.only(""), - "m.room.avatar": StateKeyFilter.only(""), - "m.room.encryption": StateKeyFilter.only(""), - "m.room.name": StateKeyFilter.only(""), - "m.room.create": StateKeyFilter.only(""), - "m.room.topic": StateKeyFilter.only(""), +from synapse.config.api import ApiConfig +from synapse.types.state import StateFilter + +DEFAULT_PREJOIN_STATE_PAIRS = { + ("m.room.join_rules", ""), + ("m.room.canonical_alias", ""), + ("m.room.avatar", ""), + ("m.room.encryption", ""), + ("m.room.name", ""), + ("m.room.create", ""), + ("m.room.topic", ""), } class TestRoomPrejoinState(StdlibTestCase): - def test_state_key_filter(self) -> None: - """Sanity check the StateKeyFilter class.""" - s = StateKeyFilter.only("foo") - self.assertIn("foo", s) - self.assertNotIn("bar", s) - self.assertNotIn("baz", s) - s.add("bar") - self.assertIn("foo", s) - self.assertIn("bar", s) - self.assertNotIn("baz", s) - - s = StateKeyFilter.any() - 
self.assertIn("foo", s) - self.assertIn("bar", s) - self.assertIn("baz", s) - s.add("bar") - self.assertIn("foo", s) - self.assertIn("bar", s) - self.assertIn("baz", s) - def read_config(self, source: str) -> ApiConfig: config = ApiConfig() config.read_config(yaml.safe_load(source)) @@ -44,7 +25,10 @@ def read_config(self, source: str) -> ApiConfig: def test_no_prejoin_state(self) -> None: config = self.read_config("foo: bar") - self.assertEqual(config.room_prejoin_state, DEFAULT_PREJOIN_STATE) + self.assertFalse(config.room_prejoin_state.has_wildcards()) + self.assertEqual( + set(config.room_prejoin_state.concrete_types()), DEFAULT_PREJOIN_STATE_PAIRS + ) def test_disable_default_event_types(self) -> None: config = self.read_config( @@ -53,7 +37,7 @@ def test_disable_default_event_types(self) -> None: disable_default_event_types: true """ ) - self.assertEqual(config.room_prejoin_state, {}) + self.assertEqual(config.room_prejoin_state, StateFilter.none()) def test_event_without_state_key(self) -> None: config = self.read_config( @@ -64,7 +48,8 @@ def test_event_without_state_key(self) -> None: - foo """ ) - self.assertEqual(config.room_prejoin_state, {"foo": StateKeyFilter.any()}) + self.assertEqual(config.room_prejoin_state.wildcard_types(), ["foo"]) + self.assertEqual(config.room_prejoin_state.concrete_types(), []) def test_event_with_specific_state_key(self) -> None: config = self.read_config( @@ -75,7 +60,11 @@ def test_event_with_specific_state_key(self) -> None: - [foo, bar] """ ) - self.assertEqual(config.room_prejoin_state, {"foo": StateKeyFilter.only("bar")}) + self.assertFalse(config.room_prejoin_state.has_wildcards()) + self.assertEqual( + set(config.room_prejoin_state.concrete_types()), + {("foo", "bar")}, + ) def test_repeated_event_with_specific_state_key(self) -> None: config = self.read_config( @@ -87,8 +76,10 @@ def test_repeated_event_with_specific_state_key(self) -> None: - [foo, baz] """ ) + 
self.assertFalse(config.room_prejoin_state.has_wildcards()) self.assertEqual( - config.room_prejoin_state, {"foo": StateKeyFilter({"bar", "baz"})} + set(config.room_prejoin_state.concrete_types()), + {("foo", "bar"), ("foo", "baz")}, ) def test_no_specific_state_key_overrides_specific_state_key(self) -> None: @@ -101,7 +92,8 @@ def test_no_specific_state_key_overrides_specific_state_key(self) -> None: - foo """ ) - self.assertEqual(config.room_prejoin_state, {"foo": StateKeyFilter.any()}) + self.assertEqual(config.room_prejoin_state.wildcard_types(), ["foo"]) + self.assertEqual(config.room_prejoin_state.concrete_types(), []) config = self.read_config( """ @@ -112,7 +104,8 @@ def test_no_specific_state_key_overrides_specific_state_key(self) -> None: - [foo, bar] """ ) - self.assertEqual(config.room_prejoin_state, {"foo": StateKeyFilter.any()}) + self.assertEqual(config.room_prejoin_state.wildcard_types(), ["foo"]) + self.assertEqual(config.room_prejoin_state.concrete_types(), []) def test_bad_event_type_entry_raises(self) -> None: with self.assertRaises(ConfigError): From af032d218f1264836e6b8f919b5219042b55f1e3 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Mon, 12 Dec 2022 17:07:37 +0000 Subject: [PATCH 11/13] Ensure test file enforces typed defs; alphabetise --- mypy.ini | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/mypy.ini b/mypy.ini index 11fd017385b9..c10865632ece 100644 --- a/mypy.ini +++ b/mypy.ini @@ -91,6 +91,9 @@ disallow_untyped_defs = False [mypy-tests.config.test_api] disallow_untyped_defs = True +[mypy-tests.federation.transport.test_client] +disallow_untyped_defs = True + [mypy-tests.handlers.test_sso] disallow_untyped_defs = True @@ -103,7 +106,7 @@ disallow_untyped_defs = True [mypy-tests.push.test_bulk_push_rule_evaluator] disallow_untyped_defs = True -[mypy-tests.test_server] +[mypy-tests.rest.*] disallow_untyped_defs = True [mypy-tests.state.test_profile] @@ -112,10 +115,10 @@ disallow_untyped_defs = True 
[mypy-tests.storage.*] disallow_untyped_defs = True -[mypy-tests.rest.*] +[mypy-tests.test_server] disallow_untyped_defs = True -[mypy-tests.federation.transport.test_client] +[mypy-tests.types.*] disallow_untyped_defs = True [mypy-tests.util.caches.*] From 837f6c2ce9dd2bdcb6c8f65d8cd0a227bf40acd5 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Mon, 12 Dec 2022 18:10:28 +0000 Subject: [PATCH 12/13] Workaround surprising get_current_state_ids --- synapse/storage/databases/main/events_worker.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index a36549e7c24f..f8e920762894 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -912,6 +912,10 @@ async def get_stripped_room_state_from_event_context( filter = state_keys_to_include selected_state_ids = await context.get_current_state_ids(filter) + # Confusingly, get_current_state_events may return events that are discarded by + # the filter, if they're in context._state_delta_due_to_event. Strip these away. + selected_state_ids = filter.filter_state(selected_state_ids) + # We know this event is not an outlier, so this must be # non-None. 
assert selected_state_ids is not None From 23ca964eeec3c579aa572e9d9f291fd7617a0508 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Mon, 12 Dec 2022 18:18:18 +0000 Subject: [PATCH 13/13] Whoops, fix mypy --- synapse/storage/databases/main/events_worker.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index f8e920762894..318fd7dc7133 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -912,14 +912,14 @@ async def get_stripped_room_state_from_event_context( filter = state_keys_to_include selected_state_ids = await context.get_current_state_ids(filter) - # Confusingly, get_current_state_events may return events that are discarded by - # the filter, if they're in context._state_delta_due_to_event. Strip these away. - selected_state_ids = filter.filter_state(selected_state_ids) - # We know this event is not an outlier, so this must be # non-None. assert selected_state_ids is not None + # Confusingly, get_current_state_ids may return events that are discarded by + # the filter, if they're in context._state_delta_due_to_event. Strip these away. + selected_state_ids = filter.filter_state(selected_state_ids) + state_to_include = await self.get_events(selected_state_ids.values()) return [