1
1
# Copyright 2015, 2016 OpenMarket Ltd
2
2
# Copyright 2017 Vector Creations Ltd
3
3
# Copyright 2018-2019 New Vector Ltd
4
- # Copyright 2019 The Matrix.org Foundation C.I.C.
4
+ # Copyright 2019-2021 The Matrix.org Foundation C.I.C.
5
5
#
6
6
# Licensed under the Apache License, Version 2.0 (the "License");
7
7
# you may not use this file except in compliance with the License.
86
86
# cf https://github.com/matrix-org/matrix-doc/pull/2326
87
87
"org.matrix.labels" : {"type" : "array" , "items" : {"type" : "string" }},
88
88
"org.matrix.not_labels" : {"type" : "array" , "items" : {"type" : "string" }},
89
+ # MSC3440, filtering by event relations.
90
+ "io.element.relation_senders" : {"type" : "array" , "items" : {"type" : "string" }},
91
+ "io.element.relation_types" : {"type" : "array" , "items" : {"type" : "string" }},
89
92
},
90
93
}
91
94
@@ -146,14 +149,16 @@ def matrix_user_id_validator(user_id_str: str) -> UserID:
146
149
147
150
class Filtering :
148
151
def __init__ (self , hs : "HomeServer" ):
149
- super (). __init__ ()
152
+ self . _hs = hs
150
153
self .store = hs .get_datastore ()
151
154
155
+ self .DEFAULT_FILTER_COLLECTION = FilterCollection (hs , {})
156
+
152
157
async def get_user_filter (
153
158
self , user_localpart : str , filter_id : Union [int , str ]
154
159
) -> "FilterCollection" :
155
160
result = await self .store .get_user_filter (user_localpart , filter_id )
156
- return FilterCollection (result )
161
+ return FilterCollection (self . _hs , result )
157
162
158
163
def add_user_filter (
159
164
self , user_localpart : str , user_filter : JsonDict
@@ -191,21 +196,22 @@ def check_valid_filter(self, user_filter_json: JsonDict) -> None:
191
196
192
197
193
198
class FilterCollection :
194
- def __init__ (self , filter_json : JsonDict ):
199
+ def __init__ (self , hs : "HomeServer" , filter_json : JsonDict ):
195
200
self ._filter_json = filter_json
196
201
197
202
room_filter_json = self ._filter_json .get ("room" , {})
198
203
199
204
self ._room_filter = Filter (
200
- {k : v for k , v in room_filter_json .items () if k in ("rooms" , "not_rooms" )}
205
+ hs ,
206
+ {k : v for k , v in room_filter_json .items () if k in ("rooms" , "not_rooms" )},
201
207
)
202
208
203
- self ._room_timeline_filter = Filter (room_filter_json .get ("timeline" , {}))
204
- self ._room_state_filter = Filter (room_filter_json .get ("state" , {}))
205
- self ._room_ephemeral_filter = Filter (room_filter_json .get ("ephemeral" , {}))
206
- self ._room_account_data = Filter (room_filter_json .get ("account_data" , {}))
207
- self ._presence_filter = Filter (filter_json .get ("presence" , {}))
208
- self ._account_data = Filter (filter_json .get ("account_data" , {}))
209
+ self ._room_timeline_filter = Filter (hs , room_filter_json .get ("timeline" , {}))
210
+ self ._room_state_filter = Filter (hs , room_filter_json .get ("state" , {}))
211
+ self ._room_ephemeral_filter = Filter (hs , room_filter_json .get ("ephemeral" , {}))
212
+ self ._room_account_data = Filter (hs , room_filter_json .get ("account_data" , {}))
213
+ self ._presence_filter = Filter (hs , filter_json .get ("presence" , {}))
214
+ self ._account_data = Filter (hs , filter_json .get ("account_data" , {}))
209
215
210
216
self .include_leave = filter_json .get ("room" , {}).get ("include_leave" , False )
211
217
self .event_fields = filter_json .get ("event_fields" , [])
@@ -232,25 +238,37 @@ def lazy_load_members(self) -> bool:
232
238
def include_redundant_members (self ) -> bool :
233
239
return self ._room_state_filter .include_redundant_members
234
240
235
- def filter_presence (
241
+ async def filter_presence (
236
242
self , events : Iterable [UserPresenceState ]
237
243
) -> List [UserPresenceState ]:
238
- return self ._presence_filter .filter (events )
244
+ return await self ._presence_filter .filter (events )
239
245
240
- def filter_account_data (self , events : Iterable [JsonDict ]) -> List [JsonDict ]:
241
- return self ._account_data .filter (events )
246
+ async def filter_account_data (self , events : Iterable [JsonDict ]) -> List [JsonDict ]:
247
+ return await self ._account_data .filter (events )
242
248
243
- def filter_room_state (self , events : Iterable [EventBase ]) -> List [EventBase ]:
244
- return self ._room_state_filter .filter (self ._room_filter .filter (events ))
249
+ async def filter_room_state (self , events : Iterable [EventBase ]) -> List [EventBase ]:
250
+ return await self ._room_state_filter .filter (
251
+ await self ._room_filter .filter (events )
252
+ )
245
253
246
- def filter_room_timeline (self , events : Iterable [EventBase ]) -> List [EventBase ]:
247
- return self ._room_timeline_filter .filter (self ._room_filter .filter (events ))
254
+ async def filter_room_timeline (
255
+ self , events : Iterable [EventBase ]
256
+ ) -> List [EventBase ]:
257
+ return await self ._room_timeline_filter .filter (
258
+ await self ._room_filter .filter (events )
259
+ )
248
260
249
- def filter_room_ephemeral (self , events : Iterable [JsonDict ]) -> List [JsonDict ]:
250
- return self ._room_ephemeral_filter .filter (self ._room_filter .filter (events ))
261
+ async def filter_room_ephemeral (self , events : Iterable [JsonDict ]) -> List [JsonDict ]:
262
+ return await self ._room_ephemeral_filter .filter (
263
+ await self ._room_filter .filter (events )
264
+ )
251
265
252
- def filter_room_account_data (self , events : Iterable [JsonDict ]) -> List [JsonDict ]:
253
- return self ._room_account_data .filter (self ._room_filter .filter (events ))
266
+ async def filter_room_account_data (
267
+ self , events : Iterable [JsonDict ]
268
+ ) -> List [JsonDict ]:
269
+ return await self ._room_account_data .filter (
270
+ await self ._room_filter .filter (events )
271
+ )
254
272
255
273
def blocks_all_presence (self ) -> bool :
256
274
return (
@@ -274,7 +292,9 @@ def blocks_all_room_timeline(self) -> bool:
274
292
275
293
276
294
class Filter :
277
- def __init__ (self , filter_json : JsonDict ):
295
+ def __init__ (self , hs : "HomeServer" , filter_json : JsonDict ):
296
+ self ._hs = hs
297
+ self ._store = hs .get_datastore ()
278
298
self .filter_json = filter_json
279
299
280
300
self .limit = filter_json .get ("limit" , 10 )
@@ -297,6 +317,20 @@ def __init__(self, filter_json: JsonDict):
297
317
self .labels = filter_json .get ("org.matrix.labels" , None )
298
318
self .not_labels = filter_json .get ("org.matrix.not_labels" , [])
299
319
320
+ # Ideally these would be rejected at the endpoint if they were provided
321
+ # and not supported, but that would involve modifying the JSON schema
322
+ # based on the homeserver configuration.
323
+ if hs .config .experimental .msc3440_enabled :
324
+ self .relation_senders = self .filter_json .get (
325
+ "io.element.relation_senders" , None
326
+ )
327
+ self .relation_types = self .filter_json .get (
328
+ "io.element.relation_types" , None
329
+ )
330
+ else :
331
+ self .relation_senders = None
332
+ self .relation_types = None
333
+
300
334
def filters_all_types (self ) -> bool :
301
335
return "*" in self .not_types
302
336
@@ -306,7 +340,7 @@ def filters_all_senders(self) -> bool:
306
340
def filters_all_rooms (self ) -> bool :
307
341
return "*" in self .not_rooms
308
342
309
- def check (self , event : FilterEvent ) -> bool :
343
+ def _check (self , event : FilterEvent ) -> bool :
310
344
"""Checks whether the filter matches the given event.
311
345
312
346
Args:
@@ -420,8 +454,30 @@ def filter_rooms(self, room_ids: Iterable[str]) -> Set[str]:
420
454
421
455
return room_ids
422
456
423
- def filter (self , events : Iterable [FilterEvent ]) -> List [FilterEvent ]:
424
- return list (filter (self .check , events ))
457
+ async def _check_event_relations (
458
+ self , events : Iterable [FilterEvent ]
459
+ ) -> List [FilterEvent ]:
460
+ # The event IDs to check, mypy doesn't understand the ifinstance check.
461
+ event_ids = [event .event_id for event in events if isinstance (event , EventBase )] # type: ignore[attr-defined]
462
+ event_ids_to_keep = set (
463
+ await self ._store .events_have_relations (
464
+ event_ids , self .relation_senders , self .relation_types
465
+ )
466
+ )
467
+
468
+ return [
469
+ event
470
+ for event in events
471
+ if not isinstance (event , EventBase ) or event .event_id in event_ids_to_keep
472
+ ]
473
+
474
+ async def filter (self , events : Iterable [FilterEvent ]) -> List [FilterEvent ]:
475
+ result = [event for event in events if self ._check (event )]
476
+
477
+ if self .relation_senders or self .relation_types :
478
+ return await self ._check_event_relations (result )
479
+
480
+ return result
425
481
426
482
def with_room_ids (self , room_ids : Iterable [str ]) -> "Filter" :
427
483
"""Returns a new filter with the given room IDs appended.
@@ -433,7 +489,7 @@ def with_room_ids(self, room_ids: Iterable[str]) -> "Filter":
433
489
filter: A new filter including the given rooms and the old
434
490
filter's rooms.
435
491
"""
436
- newFilter = Filter (self .filter_json )
492
+ newFilter = Filter (self ._hs , self . filter_json )
437
493
newFilter .rooms += room_ids
438
494
return newFilter
439
495
@@ -444,6 +500,3 @@ def _matches_wildcard(actual_value: Optional[str], filter_value: str) -> bool:
444
500
return actual_value .startswith (type_prefix )
445
501
else :
446
502
return actual_value == filter_value
447
-
448
-
449
- DEFAULT_FILTER_COLLECTION = FilterCollection ({})
0 commit comments