From f2f45ba6ed5d9d8670811fec4a4cfbdd6742aeba Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 12 Oct 2022 11:58:38 -0400 Subject: [PATCH 1/9] Support finding thread for relations of the root event. --- synapse/rest/client/receipts.py | 44 +++++++- synapse/storage/databases/main/cache.py | 1 + synapse/storage/databases/main/relations.py | 59 +++++++++++ tests/storage/test_relations.py | 111 ++++++++++++++++++++ 4 files changed, 213 insertions(+), 2 deletions(-) create mode 100644 tests/storage/test_relations.py diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py index 14dec7ac4e79..b1053a6daac1 100644 --- a/synapse/rest/client/receipts.py +++ b/synapse/rest/client/receipts.py @@ -15,7 +15,7 @@ import logging from typing import TYPE_CHECKING, Tuple -from synapse.api.constants import ReceiptTypes +from synapse.api.constants import MAIN_TIMELINE, ReceiptTypes from synapse.api.errors import Codes, SynapseError from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_json_object_from_request @@ -83,7 +83,7 @@ async def on_POST( ) # Ensure the event ID roughly correlates to the thread ID. - if thread_id != await self._main_store.get_thread_id(event_id): + if not await self._is_valid_thread_id(event_id, thread_id): raise SynapseError( 400, f"event_id {event_id} is not related to thread {thread_id}", @@ -109,6 +109,46 @@ async def on_POST( return 200, {} + async def _is_valid_thread_id(self, event_id: str, thread_id: str) -> bool: + """ + The thread ID provided must relate (in a vague sense) to the event ID. + + We check this to ensure clients aren't sending bogus receipts. + + A thread ID is considered valid if: + + 1. The event has a thread relation which matches the thread ID. + 2. The event has children events which form a thread relation (i.e. the + event is a thread root). + 2. The event is related to an event (recursively) which satisfies 1 or 2. + + Given the following DAG: + + A <---[m.thread]-- B <--[m.annotation]-- C + ^ + |--[m.reference]-- D <--[m.annotation]-- E + + It is valid to send a receipt for thread A on A, B, C, D, or E. + + It is valid to send a receipt for the main thread on A, D, and E. + + Args: + event_id: The event ID to check. + thread_id: The thread ID the event is potentially part of. + + Returns: + True if the event belongs to the given thread. + """ + + # If the receipt is on the main timeline, it is enough to check whether + # the event is directly related to a thread. + if thread_id == MAIN_TIMELINE: + return MAIN_TIMELINE == await self._main_store.get_thread_id(event_id) + + # Otherwise, check if the event is directly part of a thread, or is the + # root message (or related to the root message) of a thread. + return thread_id == await self._main_store.get_thread_id_for_receipts(event_id) + def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: ReceiptRestServlet(hs).register(http_server) diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index b47fc606c754..ed0be4abe59c 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -245,6 +245,7 @@ def _invalidate_caches_for_event( self._attempt_to_invalidate_cache("get_relations_for_event", (redacts,)) self._attempt_to_invalidate_cache("get_applicable_edit", (redacts,)) self._attempt_to_invalidate_cache("get_thread_id", (redacts,)) + self._attempt_to_invalidate_cache("get_thread_id_for_receipts", (redacts,)) if etype == EventTypes.Member: self._membership_stream_cache.entity_has_changed(state_key, stream_ordering) diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 7c54ce0b2e3d..5b87fa39fdd3 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -946,6 +946,11 @@ async def get_thread_id(self, event_id: str) -> str: Get the thread ID for an event. This considers multi-level relations, e.g. an annotation to an event which is part of a thread. + It only searches up the relations tree, i.e. it only considers events + which are the parent of the given event. + + See also get_thread_id_for_receipts. + Args: event_id: The event ID to fetch the thread ID for. @@ -953,6 +958,7 @@ async def get_thread_id(self, event_id: str) -> str: The event ID of the root event in the thread, if this event is part of a thread. "main", otherwise. """ + # Since event relations form a tree, we should only ever find 0 or 1 # results from the below query. sql = """ @@ -978,6 +984,59 @@ def _get_thread_id(txn: LoggingTransaction) -> str: return await self.db_pool.runInteraction("get_thread_id", _get_thread_id) + @cached() + async def get_thread_id_for_receipts(self, event_id: str) -> str: + """ + Get the thread ID for an event by traversing to the top-most related event + and confirming any children events form a thread. + + Given the following DAG: + + A <---[m.thread]-- B <--[m.annotation]-- C + ^ + |--[m.reference]-- D <--[m.annotation]-- E + + It considers A, B, C, D, and E as part of the thread. + + See also get_thread_id. + + Args: + event_id: The event ID to fetch the thread ID for. + + Returns: + The event ID of the root event in the thread, if this event is part + of a thread. "main", otherwise. + """ + + # Recurse up to the *root* node, then select relations of that to + # see if there are thread children. + sql = """ + SELECT relates_to_id FROM event_relations WHERE relates_to_id = COALESCE(( + WITH RECURSIVE related_events AS ( + SELECT event_id, relates_to_id, relation_type, 0 depth + FROM event_relations + WHERE event_id = ? + UNION SELECT e.event_id, e.relates_to_id, e.relation_type, depth + 1 + FROM event_relations e + INNER JOIN related_events r ON r.relates_to_id = e.event_id + WHERE depth <= 3 + ) SELECT relates_to_id FROM related_events ORDER BY depth DESC LIMIT 1 + ), ?) AND relation_type = 'm.thread' LIMIT 1; + """ + + def _get_related_thread_id(txn: LoggingTransaction) -> str: + txn.execute(sql, (event_id, event_id)) + row = txn.fetchone() + if row: + return row[0] + + # If no thread was found, it is part of the main timeline. + return MAIN_TIMELINE + + return await self.db_pool.runInteraction( + "get_related_thread_id", _get_related_thread_id + ) + class RelationsStore(RelationsWorkerStore): pass diff --git a/tests/storage/test_relations.py b/tests/storage/test_relations.py new file mode 100644 index 000000000000..a157de37e633 --- /dev/null +++ b/tests/storage/test_relations.py @@ -0,0 +1,111 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.test.proto_helpers import MemoryReactor + +from synapse.api.constants import MAIN_TIMELINE +from synapse.server import HomeServer +from synapse.util import Clock + +from tests import unittest + + +class RelationsStoreTestCase(unittest.HomeserverTestCase): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + """ + Creates a DAG: + + A <---[m.thread]-- B <--[m.annotation]-- C + ^ + |--[m.reference]-- D <--[m.annotation]-- E + + F <--[m.annotation]-- G + + """ + self._main_store = self.hs.get_datastores().main + + self._create_relation("A", "B", "m.thread") + self._create_relation("B", "C", "m.annotation") + self._create_relation("A", "D", "m.reference") + self._create_relation("A", "E", "m.annotation") + self._create_relation("F", "G", "m.annotation") + + def _create_relation(self, parent_id: str, event_id: str, rel_type: str) -> None: + self.get_success( + self._main_store.db_pool.simple_insert( + table="event_relations", + values={ + "event_id": event_id, + "relates_to_id": parent_id, + "relation_type": rel_type, + }, + ) + ) + + def test_get_thread_id(self) -> None: + """ + Ensure that get_thread_id only searches up the tree for threads. + """ + # The thread itself and children of it return the thread. + thread_id = self.get_success(self._main_store.get_thread_id("B")) + self.assertEqual("A", thread_id) + + thread_id = self.get_success(self._main_store.get_thread_id("C")) + self.assertEqual("A", thread_id) + + # But the root and events related to the root do not. + thread_id = self.get_success(self._main_store.get_thread_id("A")) + self.assertEqual(MAIN_TIMELINE, thread_id) + + thread_id = self.get_success(self._main_store.get_thread_id("D")) + self.assertEqual(MAIN_TIMELINE, thread_id) + + thread_id = self.get_success(self._main_store.get_thread_id("E")) + self.assertEqual(MAIN_TIMELINE, thread_id) + + # Events which are not related to a thread at all should return the + # main timeline. + thread_id = self.get_success(self._main_store.get_thread_id("F")) + self.assertEqual(MAIN_TIMELINE, thread_id) + + thread_id = self.get_success(self._main_store.get_thread_id("G")) + self.assertEqual(MAIN_TIMELINE, thread_id) + + def test_get_thread_id_for_receipts(self) -> None: + """ + Ensure that get_thread_id_for_receipts searches up and down the tree for a thread. + """ + # All of the events are considered related to this thread. + thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("A")) + self.assertEqual("A", thread_id) + + thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("B")) + self.assertEqual("A", thread_id) + + thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("C")) + self.assertEqual("A", thread_id) + + thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("D")) + self.assertEqual("A", thread_id) + + thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("E")) + self.assertEqual("A", thread_id) + + # Events which are not related to a thread at all should return the + # main timeline. + thread_id = self.get_success(self._main_store.get_thread_id("F")) + self.assertEqual(MAIN_TIMELINE, thread_id) + + thread_id = self.get_success(self._main_store.get_thread_id("G")) + self.assertEqual(MAIN_TIMELINE, thread_id) From 6c053c72af9cde3f14447820f69dc922f4faf5ae Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 12 Oct 2022 12:10:06 -0400 Subject: [PATCH 2/9] Apply depth checks to simply query. --- synapse/storage/databases/main/relations.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 5b87fa39fdd3..66169d6193d3 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -963,18 +963,22 @@ async def get_thread_id(self, event_id: str) -> str: # results from the below query. sql = """ WITH RECURSIVE related_events AS ( - SELECT event_id, relates_to_id, relation_type + SELECT event_id, relates_to_id, relation_type, 0 depth FROM event_relations WHERE event_id = ? - UNION SELECT e.event_id, e.relates_to_id, e.relation_type + UNION SELECT e.event_id, e.relates_to_id, e.relation_type, depth + 1 FROM event_relations e INNER JOIN related_events r ON r.relates_to_id = e.event_id - ) SELECT relates_to_id FROM related_events WHERE relation_type = 'm.thread'; + WHERE depth <= 3 + ) + SELECT relates_to_id FROM related_events + WHERE relation_type = 'm.thread' + ORDER BY depth DESC + LIMIT 1; """ def _get_thread_id(txn: LoggingTransaction) -> str: txn.execute(sql, (event_id,)) - # TODO Should we ensure there's only a single result here? row = txn.fetchone() if row: return row[0] @@ -1020,7 +1024,10 @@ async def get_thread_id_for_receipts(self, event_id: str) -> str: FROM event_relations e INNER JOIN related_events r ON r.relates_to_id = e.event_id WHERE depth <= 3 - ) SELECT relates_to_id FROM related_events ORDER BY depth DESC LIMIT 1 + ) + SELECT relates_to_id FROM related_events + ORDER BY depth DESC + LIMIT 1 ), ?) AND relation_type = 'm.thread' LIMIT 1; """ From 4629fee40088371f05850ef4ee402f30c26893f9 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 13 Oct 2022 13:03:09 -0400 Subject: [PATCH 3/9] Newsfragment --- changelog.d/14174.feature | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/14174.feature diff --git a/changelog.d/14174.feature b/changelog.d/14174.feature new file mode 100644 index 000000000000..5d0ae16e131b --- /dev/null +++ b/changelog.d/14174.feature @@ -0,0 +1 @@ +Support for thread-specific notifications & receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771) and [MSC3773](https://github.com/matrix-org/matrix-spec-proposals/pull/3773)). From 8ea1c82190c496e5172ff3d301fcf162b740d827 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 14 Oct 2022 10:44:58 -0400 Subject: [PATCH 4/9] Clarify wording. Co-authored-by: David Robertson --- synapse/rest/client/receipts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py index b1053a6daac1..de68091884af 100644 --- a/synapse/rest/client/receipts.py +++ b/synapse/rest/client/receipts.py @@ -137,7 +137,7 @@ async def _is_valid_thread_id(self, event_id: str, thread_id: str) -> bool: thread_id: The thread ID the event is potentially part of. Returns: - True if the event belongs to the given thread. + True if the event belongs to the given thread, otherwise False. """ # If the receipt is on the main timeline, it is enough to check whether From 931ff67457af6aee70bd993d84c291cf3240900d Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 14 Oct 2022 11:40:47 -0400 Subject: [PATCH 5/9] main thread -> main timeline --- synapse/rest/client/receipts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py index de68091884af..58141e006299 100644 --- a/synapse/rest/client/receipts.py +++ b/synapse/rest/client/receipts.py @@ -130,7 +130,7 @@ async def _is_valid_thread_id(self, event_id: str, thread_id: str) -> bool: It is valid to send a receipt for thread A on A, B, C, D, or E. - It is valid to send a receipt for the main thread on A, D, and E. + It is valid to send a receipt for the main timeline on A, D, and E. Args: event_id: The event ID to check. From 0baa609af794f87e4d2fb729c532414f2c4b428f Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 14 Oct 2022 11:42:05 -0400 Subject: [PATCH 6/9] Fix test data. --- tests/storage/test_relations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/storage/test_relations.py b/tests/storage/test_relations.py index a157de37e633..cd1d00208b69 100644 --- a/tests/storage/test_relations.py +++ b/tests/storage/test_relations.py @@ -38,7 +38,7 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self._create_relation("A", "B", "m.thread") self._create_relation("B", "C", "m.annotation") self._create_relation("A", "D", "m.reference") - self._create_relation("A", "E", "m.annotation") + self._create_relation("D", "E", "m.annotation") self._create_relation("F", "G", "m.annotation") def _create_relation(self, parent_id: str, event_id: str, rel_type: str) -> None: From f9963354a499dcd121c503536b3122b875352678 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 14 Oct 2022 11:44:15 -0400 Subject: [PATCH 7/9] Review comments. --- synapse/rest/client/receipts.py | 9 ++++----- synapse/storage/databases/main/relations.py | 3 ++- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py index 58141e006299..e7f7f205f464 100644 --- a/synapse/rest/client/receipts.py +++ b/synapse/rest/client/receipts.py @@ -83,7 +83,7 @@ async def on_POST( ) # Ensure the event ID roughly correlates to the thread ID. - if not await self._is_valid_thread_id(event_id, thread_id): + if not await self._is_event_in_thread(event_id, thread_id): raise SynapseError( 400, f"event_id {event_id} is not related to thread {thread_id}", @@ -109,11 +109,10 @@ async def on_POST( return 200, {} - async def _is_valid_thread_id(self, event_id: str, thread_id: str) -> bool: + async def _is_event_in_thread(self, event_id: str, thread_id: str) -> bool: """ - The thread ID provided must relate (in a vague sense) to the event ID. - - We check this to ensure clients aren't sending bogus receipts. + The event must be related to the thread ID (in a vague sense) to ensure + clients aren't sending bogus receipts. A thread ID is considered valid if: diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 66169d6193d3..5432ae2299bd 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -1000,7 +1000,8 @@ async def get_thread_id_for_receipts(self, event_id: str) -> str: ^ |--[m.reference]-- D <--[m.annotation]-- E - It considers A, B, C, D, and E as part of the thread. + get_thread_id_for_receipts(X) considers events A, B, C, D, and E as part + of thread A. See also get_thread_id. From d93eaafe7c12e4fd566db20d3a4e3aa808845295 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 14 Oct 2022 13:17:36 -0400 Subject: [PATCH 8/9] Clarifications from review. --- synapse/rest/client/receipts.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py index e7f7f205f464..52b7940e059a 100644 --- a/synapse/rest/client/receipts.py +++ b/synapse/rest/client/receipts.py @@ -119,7 +119,8 @@ async def _is_event_in_thread(self, event_id: str, thread_id: str) -> bool: 1. The event has a thread relation which matches the thread ID. 2. The event has children events which form a thread relation (i.e. the event is a thread root). - 2. The event is related to an event (recursively) which satisfies 1 or 2. + 3. The event is related to an event (recursively, up to a point) which + satisfies 1 or 2. Given the following DAG: From ddd1e3b290df8941c0eabdf2dcbae10d5edd855c Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 14 Oct 2022 13:30:10 -0400 Subject: [PATCH 9/9] Additional clarifications. --- synapse/rest/client/receipts.py | 10 +++---- synapse/storage/databases/main/relations.py | 31 +++++++++++++++++---- 2 files changed, 30 insertions(+), 11 deletions(-) diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py index 52b7940e059a..18a282b22cf9 100644 --- a/synapse/rest/client/receipts.py +++ b/synapse/rest/client/receipts.py @@ -114,12 +114,12 @@ async def _is_event_in_thread(self, event_id: str, thread_id: str) -> bool: The event must be related to the thread ID (in a vague sense) to ensure clients aren't sending bogus receipts. - A thread ID is considered valid if: + A thread ID is considered valid for a given event E if: - 1. The event has a thread relation which matches the thread ID. - 2. The event has children events which form a thread relation (i.e. the - event is a thread root). - 3. The event is related to an event (recursively, up to a point) which + 1. E has a thread relation which matches the thread ID; + 2. E has another event which has a thread relation to E matching the + thread ID; or + 3. E is recursively related (via any rel_type) to an event which satisfies 1 or 2. Given the following DAG: diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 5432ae2299bd..1de62ee9dfba 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -946,8 +946,17 @@ async def get_thread_id(self, event_id: str) -> str: Get the thread ID for an event. This considers multi-level relations, e.g. an annotation to an event which is part of a thread. - It only searches up the relations tree, i.e. it only considers events - which are the parent of the given event. + It only searches up the relations tree, i.e. it only searches for events + which the given event is related to (and which those events are related + to, etc.) + + Given the following DAG: + + A <---[m.thread]-- B <--[m.annotation]-- C + ^ + |--[m.reference]-- D <--[m.annotation]-- E + + get_thread_id(X) considers events B and C as part of thread A. See also get_thread_id_for_receipts. @@ -959,8 +968,13 @@ async def get_thread_id(self, event_id: str) -> str: of a thread. "main", otherwise. """ - # Since event relations form a tree, we should only ever find 0 or 1 - # results from the below query. + # Recurse event relations up to the *root* event, then search that chain + # of relations for a thread relation. If one is found, the root event is + # returned. + # + # Note that this should only ever find 0 or 1 entries since it is invalid + # for an event to have a thread relation to an event which also has a + # relation. sql = """ WITH RECURSIVE related_events AS ( SELECT event_id, relates_to_id, relation_type, 0 depth @@ -1013,8 +1027,13 @@ async def get_thread_id_for_receipts(self, event_id: str) -> str: of a thread. "main", otherwise. """ - # Recurse up to the *root* node, then select relations of that to - # see if there are thread children. + # Recurse event relations up to the *root* event, then search for any events + # related to that root node for a thread relation. If one is found, the + # root event is returned. + # + # Note that there cannot be thread relations in the middle of the chain since + # it is invalid for an event to have a thread relation to an event which also + # has a relation. sql = """ SELECT relates_to_id FROM event_relations WHERE relates_to_id = COALESCE(( WITH RECURSIVE related_events AS (