Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Support stable and unstable filtering by relations.
Browse files Browse the repository at this point in the history
  • Loading branch information
clokep committed Mar 7, 2022
1 parent 3d9e13b commit ae46d1b
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 37 deletions.
21 changes: 13 additions & 8 deletions synapse/api/filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,9 @@
"org.matrix.labels": {"type": "array", "items": {"type": "string"}},
"org.matrix.not_labels": {"type": "array", "items": {"type": "string"}},
# MSC3440, filtering by event relations.
"related_by_senders": {"type": "array", "items": {"type": "string"}},
"io.element.relation_senders": {"type": "array", "items": {"type": "string"}},
"related_by_rel_types": {"type": "array", "items": {"type": "string"}},
"io.element.relation_types": {"type": "array", "items": {"type": "string"}},
},
}
Expand Down Expand Up @@ -322,15 +324,18 @@ def __init__(self, hs: "HomeServer", filter_json: JsonDict):
# and not supported, but that would involve modifying the JSON schema
# based on the homeserver configuration.
if hs.config.experimental.msc3440_enabled:
self.relation_senders = self.filter_json.get(
"io.element.relation_senders", None
# Fallback to the unstable prefix if the stable version is not given.
self.related_by_senders = self.filter_json.get(
"related_by_senders",
self.filter_json.get("io.element.relation_senders", None),
)
self.relation_types = self.filter_json.get(
"io.element.relation_types", None
self.related_by_rel_types = self.filter_json.get(
"related_by_rel_types",
self.filter_json.get("io.element.relation_types", None),
)
else:
self.relation_senders = None
self.relation_types = None
self.related_by_senders = None
self.related_by_rel_types = None

def filters_all_types(self) -> bool:
return "*" in self.not_types
Expand Down Expand Up @@ -461,7 +466,7 @@ async def _check_event_relations(
event_ids = [event.event_id for event in events if isinstance(event, EventBase)] # type: ignore[attr-defined]
event_ids_to_keep = set(
await self._store.events_have_relations(
event_ids, self.relation_senders, self.relation_types
event_ids, self.related_by_senders, self.related_by_rel_types
)
)

Expand All @@ -474,7 +479,7 @@ async def _check_event_relations(
async def filter(self, events: Iterable[FilterEvent]) -> List[FilterEvent]:
result = [event for event in events if self._check(event)]

if self.relation_senders or self.relation_types:
if self.related_by_senders or self.related_by_rel_types:
return await self._check_event_relations(result)

return result
Expand Down
18 changes: 10 additions & 8 deletions synapse/storage/databases/main/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,21 +325,23 @@ def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
args.extend(event_filter.labels)

# Filter on relation_senders / relation types from the joined tables.
if event_filter.relation_senders:
if event_filter.related_by_senders:
clauses.append(
"(%s)"
% " OR ".join(
"related_event.sender = ?" for _ in event_filter.relation_senders
"related_event.sender = ?" for _ in event_filter.related_by_senders
)
)
args.extend(event_filter.relation_senders)
args.extend(event_filter.related_by_senders)

if event_filter.relation_types:
if event_filter.related_by_rel_types:
clauses.append(
"(%s)"
% " OR ".join("relation_type = ?" for _ in event_filter.relation_types)
% " OR ".join(
"relation_type = ?" for _ in event_filter.related_by_rel_types
)
)
args.extend(event_filter.relation_types)
args.extend(event_filter.related_by_rel_types)

return " AND ".join(clauses), args

Expand Down Expand Up @@ -1203,15 +1205,15 @@ def _paginate_room_events_txn(
# If there is a filter on relation_senders and relation_types join to the
# relations table.
if event_filter and (
event_filter.relation_senders or event_filter.relation_types
event_filter.related_by_senders or event_filter.related_by_rel_types
):
# Filtering by relations could cause the same event to appear multiple
# times (since there's no limit on the number of relations to an event).
needs_distinct = True
join_clause += """
LEFT JOIN event_relations AS relation ON (event.event_id = relation.relates_to_id)
"""
if event_filter.relation_senders:
if event_filter.related_by_senders:
join_clause += """
LEFT JOIN events AS related_event ON (relation.event_id = related_event.event_id)
"""
Expand Down
18 changes: 8 additions & 10 deletions tests/rest/client/test_rooms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2141,21 +2141,19 @@ def _filter_messages(self, filter: JsonDict) -> List[JsonDict]:

def test_filter_relation_senders(self) -> None:
# Messages which second user reacted to.
filter = {"io.element.relation_senders": [self.second_user_id]}
filter = {"related_by_senders": [self.second_user_id]}
chunk = self._filter_messages(filter)
self.assertEqual(len(chunk), 1, chunk)
self.assertEqual(chunk[0]["event_id"], self.event_id_1)

# Messages which third user reacted to.
filter = {"io.element.relation_senders": [self.third_user_id]}
filter = {"related_by_senders": [self.third_user_id]}
chunk = self._filter_messages(filter)
self.assertEqual(len(chunk), 1, chunk)
self.assertEqual(chunk[0]["event_id"], self.event_id_2)

# Messages which either user reacted to.
filter = {
"io.element.relation_senders": [self.second_user_id, self.third_user_id]
}
filter = {"related_by_senders": [self.second_user_id, self.third_user_id]}
chunk = self._filter_messages(filter)
self.assertEqual(len(chunk), 2, chunk)
self.assertCountEqual(
Expand All @@ -2164,20 +2162,20 @@ def test_filter_relation_senders(self) -> None:

def test_filter_relation_type(self) -> None:
# Messages which have annotations.
filter = {"io.element.relation_types": [RelationTypes.ANNOTATION]}
filter = {"related_by_rel_types": [RelationTypes.ANNOTATION]}
chunk = self._filter_messages(filter)
self.assertEqual(len(chunk), 1, chunk)
self.assertEqual(chunk[0]["event_id"], self.event_id_1)

# Messages which have references.
filter = {"io.element.relation_types": [RelationTypes.REFERENCE]}
filter = {"related_by_rel_types": [RelationTypes.REFERENCE]}
chunk = self._filter_messages(filter)
self.assertEqual(len(chunk), 1, chunk)
self.assertEqual(chunk[0]["event_id"], self.event_id_2)

# Messages which have either annotations or references.
filter = {
"io.element.relation_types": [
"related_by_rel_types": [
RelationTypes.ANNOTATION,
RelationTypes.REFERENCE,
]
Expand All @@ -2191,8 +2189,8 @@ def test_filter_relation_type(self) -> None:
def test_filter_relation_senders_and_type(self) -> None:
# Messages which second user reacted to.
filter = {
"io.element.relation_senders": [self.second_user_id],
"io.element.relation_types": [RelationTypes.ANNOTATION],
"related_by_senders": [self.second_user_id],
"related_by_rel_types": [RelationTypes.ANNOTATION],
}
chunk = self._filter_messages(filter)
self.assertEqual(len(chunk), 1, chunk)
Expand Down
20 changes: 9 additions & 11 deletions tests/storage/test_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,21 +129,19 @@ def _filter_messages(self, filter: JsonDict) -> List[EventBase]:

def test_filter_relation_senders(self):
# Messages which second user reacted to.
filter = {"io.element.relation_senders": [self.second_user_id]}
filter = {"related_by_senders": [self.second_user_id]}
chunk = self._filter_messages(filter)
self.assertEqual(len(chunk), 1, chunk)
self.assertEqual(chunk[0].event_id, self.event_id_1)

# Messages which third user reacted to.
filter = {"io.element.relation_senders": [self.third_user_id]}
filter = {"related_by_senders": [self.third_user_id]}
chunk = self._filter_messages(filter)
self.assertEqual(len(chunk), 1, chunk)
self.assertEqual(chunk[0].event_id, self.event_id_2)

# Messages which either user reacted to.
filter = {
"io.element.relation_senders": [self.second_user_id, self.third_user_id]
}
filter = {"related_by_senders": [self.second_user_id, self.third_user_id]}
chunk = self._filter_messages(filter)
self.assertEqual(len(chunk), 2, chunk)
self.assertCountEqual(
Expand All @@ -152,20 +150,20 @@ def test_filter_relation_senders(self):

def test_filter_relation_type(self):
# Messages which have annotations.
filter = {"io.element.relation_types": [RelationTypes.ANNOTATION]}
filter = {"related_by_rel_types": [RelationTypes.ANNOTATION]}
chunk = self._filter_messages(filter)
self.assertEqual(len(chunk), 1, chunk)
self.assertEqual(chunk[0].event_id, self.event_id_1)

# Messages which have references.
filter = {"io.element.relation_types": [RelationTypes.REFERENCE]}
filter = {"related_by_rel_types": [RelationTypes.REFERENCE]}
chunk = self._filter_messages(filter)
self.assertEqual(len(chunk), 1, chunk)
self.assertEqual(chunk[0].event_id, self.event_id_2)

# Messages which have either annotations or references.
filter = {
"io.element.relation_types": [
"related_by_rel_types": [
RelationTypes.ANNOTATION,
RelationTypes.REFERENCE,
]
Expand All @@ -179,8 +177,8 @@ def test_filter_relation_type(self):
def test_filter_relation_senders_and_type(self):
# Messages which second user reacted to.
filter = {
"io.element.relation_senders": [self.second_user_id],
"io.element.relation_types": [RelationTypes.ANNOTATION],
"related_by_senders": [self.second_user_id],
"related_by_rel_types": [RelationTypes.ANNOTATION],
}
chunk = self._filter_messages(filter)
self.assertEqual(len(chunk), 1, chunk)
Expand All @@ -201,7 +199,7 @@ def test_duplicate_relation(self):
tok=self.second_tok,
)

filter = {"io.element.relation_senders": [self.second_user_id]}
filter = {"related_by_senders": [self.second_user_id]}
chunk = self._filter_messages(filter)
self.assertEqual(len(chunk), 1, chunk)
self.assertEqual(chunk[0].event_id, self.event_id_1)

0 comments on commit ae46d1b

Please sign in to comment.