Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit a19d01c

Browse files
authored
Support filtering by relations per MSC3440 (#11236)
Adds experimental support for `relation_types` and `relation_senders` fields for filters.
1 parent 4b3e30c commit a19d01c

File tree

15 files changed

+680
-110
lines changed

15 files changed

+680
-110
lines changed

changelog.d/11236.feature

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Support filtering by relation senders & types per [MSC3440](https://github.com/matrix-org/matrix-doc/pull/3440).

synapse/api/filtering.py

+84-31
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Copyright 2015, 2016 OpenMarket Ltd
22
# Copyright 2017 Vector Creations Ltd
33
# Copyright 2018-2019 New Vector Ltd
4-
# Copyright 2019 The Matrix.org Foundation C.I.C.
4+
# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
55
#
66
# Licensed under the Apache License, Version 2.0 (the "License");
77
# you may not use this file except in compliance with the License.
@@ -86,6 +86,9 @@
8686
# cf https://github.com/matrix-org/matrix-doc/pull/2326
8787
"org.matrix.labels": {"type": "array", "items": {"type": "string"}},
8888
"org.matrix.not_labels": {"type": "array", "items": {"type": "string"}},
89+
# MSC3440, filtering by event relations.
90+
"io.element.relation_senders": {"type": "array", "items": {"type": "string"}},
91+
"io.element.relation_types": {"type": "array", "items": {"type": "string"}},
8992
},
9093
}
9194

@@ -146,14 +149,16 @@ def matrix_user_id_validator(user_id_str: str) -> UserID:
146149

147150
class Filtering:
148151
def __init__(self, hs: "HomeServer"):
149-
super().__init__()
152+
self._hs = hs
150153
self.store = hs.get_datastore()
151154

155+
self.DEFAULT_FILTER_COLLECTION = FilterCollection(hs, {})
156+
152157
async def get_user_filter(
153158
self, user_localpart: str, filter_id: Union[int, str]
154159
) -> "FilterCollection":
155160
result = await self.store.get_user_filter(user_localpart, filter_id)
156-
return FilterCollection(result)
161+
return FilterCollection(self._hs, result)
157162

158163
def add_user_filter(
159164
self, user_localpart: str, user_filter: JsonDict
@@ -191,21 +196,22 @@ def check_valid_filter(self, user_filter_json: JsonDict) -> None:
191196

192197

193198
class FilterCollection:
194-
def __init__(self, filter_json: JsonDict):
199+
def __init__(self, hs: "HomeServer", filter_json: JsonDict):
195200
self._filter_json = filter_json
196201

197202
room_filter_json = self._filter_json.get("room", {})
198203

199204
self._room_filter = Filter(
200-
{k: v for k, v in room_filter_json.items() if k in ("rooms", "not_rooms")}
205+
hs,
206+
{k: v for k, v in room_filter_json.items() if k in ("rooms", "not_rooms")},
201207
)
202208

203-
self._room_timeline_filter = Filter(room_filter_json.get("timeline", {}))
204-
self._room_state_filter = Filter(room_filter_json.get("state", {}))
205-
self._room_ephemeral_filter = Filter(room_filter_json.get("ephemeral", {}))
206-
self._room_account_data = Filter(room_filter_json.get("account_data", {}))
207-
self._presence_filter = Filter(filter_json.get("presence", {}))
208-
self._account_data = Filter(filter_json.get("account_data", {}))
209+
self._room_timeline_filter = Filter(hs, room_filter_json.get("timeline", {}))
210+
self._room_state_filter = Filter(hs, room_filter_json.get("state", {}))
211+
self._room_ephemeral_filter = Filter(hs, room_filter_json.get("ephemeral", {}))
212+
self._room_account_data = Filter(hs, room_filter_json.get("account_data", {}))
213+
self._presence_filter = Filter(hs, filter_json.get("presence", {}))
214+
self._account_data = Filter(hs, filter_json.get("account_data", {}))
209215

210216
self.include_leave = filter_json.get("room", {}).get("include_leave", False)
211217
self.event_fields = filter_json.get("event_fields", [])
@@ -232,25 +238,37 @@ def lazy_load_members(self) -> bool:
232238
def include_redundant_members(self) -> bool:
233239
return self._room_state_filter.include_redundant_members
234240

235-
def filter_presence(
241+
async def filter_presence(
236242
self, events: Iterable[UserPresenceState]
237243
) -> List[UserPresenceState]:
238-
return self._presence_filter.filter(events)
244+
return await self._presence_filter.filter(events)
239245

240-
def filter_account_data(self, events: Iterable[JsonDict]) -> List[JsonDict]:
241-
return self._account_data.filter(events)
246+
async def filter_account_data(self, events: Iterable[JsonDict]) -> List[JsonDict]:
247+
return await self._account_data.filter(events)
242248

243-
def filter_room_state(self, events: Iterable[EventBase]) -> List[EventBase]:
244-
return self._room_state_filter.filter(self._room_filter.filter(events))
249+
async def filter_room_state(self, events: Iterable[EventBase]) -> List[EventBase]:
250+
return await self._room_state_filter.filter(
251+
await self._room_filter.filter(events)
252+
)
245253

246-
def filter_room_timeline(self, events: Iterable[EventBase]) -> List[EventBase]:
247-
return self._room_timeline_filter.filter(self._room_filter.filter(events))
254+
async def filter_room_timeline(
255+
self, events: Iterable[EventBase]
256+
) -> List[EventBase]:
257+
return await self._room_timeline_filter.filter(
258+
await self._room_filter.filter(events)
259+
)
248260

249-
def filter_room_ephemeral(self, events: Iterable[JsonDict]) -> List[JsonDict]:
250-
return self._room_ephemeral_filter.filter(self._room_filter.filter(events))
261+
async def filter_room_ephemeral(self, events: Iterable[JsonDict]) -> List[JsonDict]:
262+
return await self._room_ephemeral_filter.filter(
263+
await self._room_filter.filter(events)
264+
)
251265

252-
def filter_room_account_data(self, events: Iterable[JsonDict]) -> List[JsonDict]:
253-
return self._room_account_data.filter(self._room_filter.filter(events))
266+
async def filter_room_account_data(
267+
self, events: Iterable[JsonDict]
268+
) -> List[JsonDict]:
269+
return await self._room_account_data.filter(
270+
await self._room_filter.filter(events)
271+
)
254272

255273
def blocks_all_presence(self) -> bool:
256274
return (
@@ -274,7 +292,9 @@ def blocks_all_room_timeline(self) -> bool:
274292

275293

276294
class Filter:
277-
def __init__(self, filter_json: JsonDict):
295+
def __init__(self, hs: "HomeServer", filter_json: JsonDict):
296+
self._hs = hs
297+
self._store = hs.get_datastore()
278298
self.filter_json = filter_json
279299

280300
self.limit = filter_json.get("limit", 10)
@@ -297,6 +317,20 @@ def __init__(self, filter_json: JsonDict):
297317
self.labels = filter_json.get("org.matrix.labels", None)
298318
self.not_labels = filter_json.get("org.matrix.not_labels", [])
299319

320+
# Ideally these would be rejected at the endpoint if they were provided
321+
# and not supported, but that would involve modifying the JSON schema
322+
# based on the homeserver configuration.
323+
if hs.config.experimental.msc3440_enabled:
324+
self.relation_senders = self.filter_json.get(
325+
"io.element.relation_senders", None
326+
)
327+
self.relation_types = self.filter_json.get(
328+
"io.element.relation_types", None
329+
)
330+
else:
331+
self.relation_senders = None
332+
self.relation_types = None
333+
300334
def filters_all_types(self) -> bool:
301335
return "*" in self.not_types
302336

@@ -306,7 +340,7 @@ def filters_all_senders(self) -> bool:
306340
def filters_all_rooms(self) -> bool:
307341
return "*" in self.not_rooms
308342

309-
def check(self, event: FilterEvent) -> bool:
343+
def _check(self, event: FilterEvent) -> bool:
310344
"""Checks whether the filter matches the given event.
311345
312346
Args:
@@ -420,8 +454,30 @@ def filter_rooms(self, room_ids: Iterable[str]) -> Set[str]:
420454

421455
return room_ids
422456

423-
def filter(self, events: Iterable[FilterEvent]) -> List[FilterEvent]:
424-
return list(filter(self.check, events))
457+
async def _check_event_relations(
458+
self, events: Iterable[FilterEvent]
459+
) -> List[FilterEvent]:
460+
# The event IDs to check, mypy doesn't understand the ifinstance check.
461+
event_ids = [event.event_id for event in events if isinstance(event, EventBase)] # type: ignore[attr-defined]
462+
event_ids_to_keep = set(
463+
await self._store.events_have_relations(
464+
event_ids, self.relation_senders, self.relation_types
465+
)
466+
)
467+
468+
return [
469+
event
470+
for event in events
471+
if not isinstance(event, EventBase) or event.event_id in event_ids_to_keep
472+
]
473+
474+
async def filter(self, events: Iterable[FilterEvent]) -> List[FilterEvent]:
475+
result = [event for event in events if self._check(event)]
476+
477+
if self.relation_senders or self.relation_types:
478+
return await self._check_event_relations(result)
479+
480+
return result
425481

426482
def with_room_ids(self, room_ids: Iterable[str]) -> "Filter":
427483
"""Returns a new filter with the given room IDs appended.
@@ -433,7 +489,7 @@ def with_room_ids(self, room_ids: Iterable[str]) -> "Filter":
433489
filter: A new filter including the given rooms and the old
434490
filter's rooms.
435491
"""
436-
newFilter = Filter(self.filter_json)
492+
newFilter = Filter(self._hs, self.filter_json)
437493
newFilter.rooms += room_ids
438494
return newFilter
439495

@@ -444,6 +500,3 @@ def _matches_wildcard(actual_value: Optional[str], filter_value: str) -> bool:
444500
return actual_value.startswith(type_prefix)
445501
else:
446502
return actual_value == filter_value
447-
448-
449-
DEFAULT_FILTER_COLLECTION = FilterCollection({})

synapse/handlers/pagination.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -424,7 +424,7 @@ async def get_messages(
424424

425425
if events:
426426
if event_filter:
427-
events = event_filter.filter(events)
427+
events = await event_filter.filter(events)
428428

429429
events = await filter_events_for_client(
430430
self.storage, user_id, events, is_peeking=(member_event_id is None)

synapse/handlers/room.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -1158,8 +1158,10 @@ async def filter_evts(events: List[EventBase]) -> List[EventBase]:
11581158
)
11591159

11601160
if event_filter:
1161-
results["events_before"] = event_filter.filter(results["events_before"])
1162-
results["events_after"] = event_filter.filter(results["events_after"])
1161+
results["events_before"] = await event_filter.filter(
1162+
results["events_before"]
1163+
)
1164+
results["events_after"] = await event_filter.filter(results["events_after"])
11631165

11641166
results["events_before"] = await filter_evts(results["events_before"])
11651167
results["events_after"] = await filter_evts(results["events_after"])
@@ -1195,7 +1197,7 @@ async def filter_evts(events: List[EventBase]) -> List[EventBase]:
11951197

11961198
state_events = list(state[last_event_id].values())
11971199
if event_filter:
1198-
state_events = event_filter.filter(state_events)
1200+
state_events = await event_filter.filter(state_events)
11991201

12001202
results["state"] = await filter_evts(state_events)
12011203

synapse/handlers/search.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ async def search(
180180
% (set(group_keys) - {"room_id", "sender"},),
181181
)
182182

183-
search_filter = Filter(filter_dict)
183+
search_filter = Filter(self.hs, filter_dict)
184184

185185
# TODO: Search through left rooms too
186186
rooms = await self.store.get_rooms_for_local_user_where_membership_is(
@@ -242,7 +242,7 @@ async def search(
242242

243243
rank_map.update({r["event"].event_id: r["rank"] for r in results})
244244

245-
filtered_events = search_filter.filter([r["event"] for r in results])
245+
filtered_events = await search_filter.filter([r["event"] for r in results])
246246

247247
events = await filter_events_for_client(
248248
self.storage, user.to_string(), filtered_events
@@ -292,7 +292,9 @@ async def search(
292292

293293
rank_map.update({r["event"].event_id: r["rank"] for r in results})
294294

295-
filtered_events = search_filter.filter([r["event"] for r in results])
295+
filtered_events = await search_filter.filter(
296+
[r["event"] for r in results]
297+
)
296298

297299
events = await filter_events_for_client(
298300
self.storage, user.to_string(), filtered_events

synapse/handlers/sync.py

+10-8
Original file line numberDiff line numberDiff line change
@@ -510,7 +510,7 @@ async def _load_filtered_recents(
510510
log_kv({"limited": limited})
511511

512512
if potential_recents:
513-
recents = sync_config.filter_collection.filter_room_timeline(
513+
recents = await sync_config.filter_collection.filter_room_timeline(
514514
potential_recents
515515
)
516516
log_kv({"recents_after_sync_filtering": len(recents)})
@@ -575,8 +575,8 @@ async def _load_filtered_recents(
575575

576576
log_kv({"loaded_recents": len(events)})
577577

578-
loaded_recents = sync_config.filter_collection.filter_room_timeline(
579-
events
578+
loaded_recents = (
579+
await sync_config.filter_collection.filter_room_timeline(events)
580580
)
581581

582582
log_kv({"loaded_recents_after_sync_filtering": len(loaded_recents)})
@@ -1015,7 +1015,7 @@ async def compute_state_delta(
10151015

10161016
return {
10171017
(e.type, e.state_key): e
1018-
for e in sync_config.filter_collection.filter_room_state(
1018+
for e in await sync_config.filter_collection.filter_room_state(
10191019
list(state.values())
10201020
)
10211021
if e.type != EventTypes.Aliases # until MSC2261 or alternative solution
@@ -1383,7 +1383,7 @@ async def _generate_sync_entry_for_account_data(
13831383
sync_config.user
13841384
)
13851385

1386-
account_data_for_user = sync_config.filter_collection.filter_account_data(
1386+
account_data_for_user = await sync_config.filter_collection.filter_account_data(
13871387
[
13881388
{"type": account_data_type, "content": content}
13891389
for account_data_type, content in account_data.items()
@@ -1448,7 +1448,7 @@ async def _generate_sync_entry_for_presence(
14481448
# Deduplicate the presence entries so that there's at most one per user
14491449
presence = list({p.user_id: p for p in presence}.values())
14501450

1451-
presence = sync_config.filter_collection.filter_presence(presence)
1451+
presence = await sync_config.filter_collection.filter_presence(presence)
14521452

14531453
sync_result_builder.presence = presence
14541454

@@ -2021,12 +2021,14 @@ async def _generate_room_entry(
20212021
)
20222022

20232023
account_data_events = (
2024-
sync_config.filter_collection.filter_room_account_data(
2024+
await sync_config.filter_collection.filter_room_account_data(
20252025
account_data_events
20262026
)
20272027
)
20282028

2029-
ephemeral = sync_config.filter_collection.filter_room_ephemeral(ephemeral)
2029+
ephemeral = await sync_config.filter_collection.filter_room_ephemeral(
2030+
ephemeral
2031+
)
20302032

20312033
if not (
20322034
always_include

synapse/rest/admin/rooms.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -583,6 +583,7 @@ class RoomEventContextServlet(RestServlet):
583583

584584
def __init__(self, hs: "HomeServer"):
585585
super().__init__()
586+
self._hs = hs
586587
self.clock = hs.get_clock()
587588
self.room_context_handler = hs.get_room_context_handler()
588589
self._event_serializer = hs.get_event_client_serializer()
@@ -600,7 +601,9 @@ async def on_GET(
600601
filter_str = parse_string(request, "filter", encoding="utf-8")
601602
if filter_str:
602603
filter_json = urlparse.unquote(filter_str)
603-
event_filter: Optional[Filter] = Filter(json_decoder.decode(filter_json))
604+
event_filter: Optional[Filter] = Filter(
605+
self._hs, json_decoder.decode(filter_json)
606+
)
604607
else:
605608
event_filter = None
606609

synapse/rest/client/room.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -550,6 +550,7 @@ class RoomMessageListRestServlet(RestServlet):
550550

551551
def __init__(self, hs: "HomeServer"):
552552
super().__init__()
553+
self._hs = hs
553554
self.pagination_handler = hs.get_pagination_handler()
554555
self.auth = hs.get_auth()
555556
self.store = hs.get_datastore()
@@ -567,7 +568,9 @@ async def on_GET(
567568
filter_str = parse_string(request, "filter", encoding="utf-8")
568569
if filter_str:
569570
filter_json = urlparse.unquote(filter_str)
570-
event_filter: Optional[Filter] = Filter(json_decoder.decode(filter_json))
571+
event_filter: Optional[Filter] = Filter(
572+
self._hs, json_decoder.decode(filter_json)
573+
)
571574
if (
572575
event_filter
573576
and event_filter.filter_json.get("event_format", "client")
@@ -672,6 +675,7 @@ class RoomEventContextServlet(RestServlet):
672675

673676
def __init__(self, hs: "HomeServer"):
674677
super().__init__()
678+
self._hs = hs
675679
self.clock = hs.get_clock()
676680
self.room_context_handler = hs.get_room_context_handler()
677681
self._event_serializer = hs.get_event_client_serializer()
@@ -688,7 +692,9 @@ async def on_GET(
688692
filter_str = parse_string(request, "filter", encoding="utf-8")
689693
if filter_str:
690694
filter_json = urlparse.unquote(filter_str)
691-
event_filter: Optional[Filter] = Filter(json_decoder.decode(filter_json))
695+
event_filter: Optional[Filter] = Filter(
696+
self._hs, json_decoder.decode(filter_json)
697+
)
692698
else:
693699
event_filter = None
694700

0 commit comments

Comments
 (0)