Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add receipt types constant #11531

Merged
merged 7 commits into from
Dec 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/11531.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add a receipt types constant for `m.read`.
4 changes: 4 additions & 0 deletions synapse/api/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,5 +253,9 @@ class GuestAccess:
FORBIDDEN: Final = "forbidden"


class ReceiptTypes:
READ: Final = "m.read"


class ReadReceiptEventFields:
MSC2285_HIDDEN: Final = "org.matrix.msc2285.hidden"
6 changes: 3 additions & 3 deletions synapse/handlers/receipts.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import logging
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple

from synapse.api.constants import ReadReceiptEventFields
from synapse.api.constants import ReadReceiptEventFields, ReceiptTypes
from synapse.appservice import ApplicationService
from synapse.streams import EventSource
from synapse.types import JsonDict, ReadReceipt, UserID, get_domain_from_id
Expand Down Expand Up @@ -178,7 +178,7 @@ def filter_out_hidden(events: List[JsonDict], user_id: str) -> List[JsonDict]:

for event_id in content.keys():
event_content = content.get(event_id, {})
m_read = event_content.get("m.read", {})
m_read = event_content.get(ReceiptTypes.READ, {})

# If m_read is missing copy over the original event_content as there is nothing to process here
if not m_read:
Expand Down Expand Up @@ -206,7 +206,7 @@ def filter_out_hidden(events: List[JsonDict], user_id: str) -> List[JsonDict]:

# Set new users unless empty
if len(new_users.keys()) > 0:
new_event["content"][event_id] = {"m.read": new_users}
new_event["content"][event_id] = {ReceiptTypes.READ: new_users}

# Append new_event to visible_events unless empty
if len(new_event["content"].keys()) > 0:
Expand Down
4 changes: 2 additions & 2 deletions synapse/handlers/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import attr
from prometheus_client import Counter

from synapse.api.constants import AccountDataTypes, EventTypes, Membership
from synapse.api.constants import AccountDataTypes, EventTypes, Membership, ReceiptTypes
from synapse.api.filtering import FilterCollection
from synapse.api.presence import UserPresenceState
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
Expand Down Expand Up @@ -1046,7 +1046,7 @@ async def unread_notifs_for_room_id(
last_unread_event_id = await self.store.get_last_receipt_event_id_for_user(
user_id=sync_config.user.to_string(),
room_id=room_id,
receipt_type="m.read",
receipt_type=ReceiptTypes.READ,
)

notifs = await self.store.get_unread_event_push_actions_by_room_for_user(
Expand Down
3 changes: 2 additions & 1 deletion synapse/push/push_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
from typing import Dict

from synapse.api.constants import ReceiptTypes
from synapse.events import EventBase
from synapse.push.presentable_names import calculate_room_name, name_from_member_event
from synapse.storage import Storage
Expand All @@ -23,7 +24,7 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -
invites = await store.get_invited_rooms_for_local_user(user_id)
joins = await store.get_rooms_for_user(user_id)

my_receipts_by_room = await store.get_receipts_for_user(user_id, "m.read")
my_receipts_by_room = await store.get_receipts_for_user(user_id, ReceiptTypes.READ)

badge = len(invites)

Expand Down
3 changes: 2 additions & 1 deletion synapse/rest/client/notifications.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import logging
from typing import TYPE_CHECKING, Tuple

from synapse.api.constants import ReceiptTypes
from synapse.events.utils import format_event_for_client_v2_without_room_id
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_integer, parse_string
Expand Down Expand Up @@ -54,7 +55,7 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
)

receipts_by_room = await self.store.get_receipts_for_user_with_orderings(
user_id, "m.read"
user_id, ReceiptTypes.READ
)

notif_event_ids = [pa["event_id"] for pa in push_actions]
Expand Down
6 changes: 3 additions & 3 deletions synapse/rest/client/read_marker.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import logging
from typing import TYPE_CHECKING, Tuple

from synapse.api.constants import ReadReceiptEventFields
from synapse.api.constants import ReadReceiptEventFields, ReceiptTypes
from synapse.api.errors import Codes, SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
Expand Down Expand Up @@ -48,7 +48,7 @@ async def on_POST(
await self.presence_handler.bump_presence_active_time(requester.user)

body = parse_json_object_from_request(request)
read_event_id = body.get("m.read", None)
read_event_id = body.get(ReceiptTypes.READ, None)
hidden = body.get(ReadReceiptEventFields.MSC2285_HIDDEN, False)

if not isinstance(hidden, bool):
Expand All @@ -62,7 +62,7 @@ async def on_POST(
if read_event_id:
await self.receipts_handler.received_client_receipt(
room_id,
"m.read",
ReceiptTypes.READ,
user_id=requester.user.to_string(),
event_id=read_event_id,
hidden=hidden,
Expand Down
4 changes: 2 additions & 2 deletions synapse/rest/client/receipts.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import re
from typing import TYPE_CHECKING, Tuple

from synapse.api.constants import ReadReceiptEventFields
from synapse.api.constants import ReadReceiptEventFields, ReceiptTypes
from synapse.api.errors import Codes, SynapseError
from synapse.http import get_request_user_agent
from synapse.http.server import HttpServer
Expand Down Expand Up @@ -53,7 +53,7 @@ async def on_POST(
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)

if receipt_type != "m.read":
if receipt_type != ReceiptTypes.READ:
raise SynapseError(400, "Receipt type must be 'm.read'")

# Do not allow older SchildiChat and Element Android clients (prior to Element/1.[012].x) to send an empty body.
Expand Down
101 changes: 68 additions & 33 deletions synapse/storage/databases/main/receipts.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,25 @@
# limitations under the License.

import logging
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple
from typing import (
TYPE_CHECKING,
Any,
Collection,
Dict,
Iterable,
List,
Optional,
Set,
Tuple,
)

from twisted.internet import defer

from synapse.api.constants import ReceiptTypes
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import ReceiptsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.types import JsonDict
Expand Down Expand Up @@ -78,17 +89,13 @@ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
"ReceiptsRoomChangeCache", self.get_max_receipt_stream_id()
)

def get_max_receipt_stream_id(self):
"""Get the current max stream ID for receipts stream

Returns:
int
"""
def get_max_receipt_stream_id(self) -> int:
"""Get the current max stream ID for receipts stream"""
return self._receipts_id_gen.get_current_token()

@cached()
async def get_users_with_read_receipts_in_room(self, room_id):
receipts = await self.get_receipts_for_room(room_id, "m.read")
async def get_users_with_read_receipts_in_room(self, room_id: str) -> Set[str]:
receipts = await self.get_receipts_for_room(room_id, ReceiptTypes.READ)
return {r["user_id"] for r in receipts}

@cached(num_args=2)
Expand Down Expand Up @@ -119,7 +126,9 @@ async def get_last_receipt_event_id_for_user(
)

@cached(num_args=2)
async def get_receipts_for_user(self, user_id, receipt_type):
async def get_receipts_for_user(
self, user_id: str, receipt_type: str
) -> Dict[str, str]:
rows = await self.db_pool.simple_select_list(
table="receipts_linearized",
keyvalues={"user_id": user_id, "receipt_type": receipt_type},
Expand All @@ -129,8 +138,10 @@ async def get_receipts_for_user(self, user_id, receipt_type):

return {row["room_id"]: row["event_id"] for row in rows}

async def get_receipts_for_user_with_orderings(self, user_id, receipt_type):
def f(txn):
async def get_receipts_for_user_with_orderings(
self, user_id: str, receipt_type: str
) -> JsonDict:
def f(txn: LoggingTransaction) -> List[Tuple[str, str, int, int]]:
sql = (
"SELECT rl.room_id, rl.event_id,"
" e.topological_ordering, e.stream_ordering"
Expand Down Expand Up @@ -209,10 +220,10 @@ async def get_linearized_receipts_for_room(
@cached(num_args=3, tree=True)
async def _get_linearized_receipts_for_room(
self, room_id: str, to_key: int, from_key: Optional[int] = None
) -> List[dict]:
) -> List[JsonDict]:
"""See get_linearized_receipts_for_room"""

def f(txn):
def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
if from_key:
sql = (
"SELECT * FROM receipts_linearized WHERE"
Expand Down Expand Up @@ -250,11 +261,13 @@ def f(txn):
list_name="room_ids",
num_args=3,
)
async def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
async def _get_linearized_receipts_for_rooms(
self, room_ids: Collection[str], to_key: int, from_key: Optional[int] = None
) -> Dict[str, List[JsonDict]]:
if not room_ids:
return {}

def f(txn):
def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
if from_key:
sql = """
SELECT * FROM receipts_linearized WHERE
Expand Down Expand Up @@ -323,7 +336,7 @@ async def get_linearized_receipts_for_all_rooms(
A dictionary of roomids to a list of receipts.
"""

def f(txn):
def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
if from_key:
sql = """
SELECT * FROM receipts_linearized WHERE
Expand Down Expand Up @@ -379,7 +392,7 @@ async def get_users_sent_receipts_between(
if last_id == current_id:
return defer.succeed([])

def _get_users_sent_receipts_between_txn(txn):
def _get_users_sent_receipts_between_txn(txn: LoggingTransaction) -> List[str]:
sql = """
SELECT DISTINCT user_id FROM receipts_linearized
WHERE ? < stream_id AND stream_id <= ?
Expand Down Expand Up @@ -419,7 +432,9 @@ async def get_all_updated_receipts(
if last_id == current_id:
return [], current_id, False

def get_all_updated_receipts_txn(txn):
def get_all_updated_receipts_txn(
txn: LoggingTransaction,
) -> Tuple[List[Tuple[int, list]], int, bool]:
sql = """
SELECT stream_id, room_id, receipt_type, user_id, event_id, data
FROM receipts_linearized
Expand All @@ -446,8 +461,8 @@ def get_all_updated_receipts_txn(txn):

def _invalidate_get_users_with_receipts_in_room(
self, room_id: str, receipt_type: str, user_id: str
):
if receipt_type != "m.read":
) -> None:
if receipt_type != ReceiptTypes.READ:
return

res = self.get_users_with_read_receipts_in_room.cache.get_immediate(
Expand All @@ -461,7 +476,9 @@ def _invalidate_get_users_with_receipts_in_room(

self.get_users_with_read_receipts_in_room.invalidate((room_id,))

def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
def invalidate_caches_for_receipt(
self, room_id: str, receipt_type: str, user_id: str
) -> None:
self.get_receipts_for_user.invalidate((user_id, receipt_type))
self._get_linearized_receipts_for_room.invalidate((room_id,))
self.get_last_receipt_event_id_for_user.invalidate(
Expand All @@ -482,11 +499,18 @@ def process_replication_rows(self, stream_name, instance_name, token, rows):
return super().process_replication_rows(stream_name, instance_name, token, rows)

def insert_linearized_receipt_txn(
self, txn, room_id, receipt_type, user_id, event_id, data, stream_id
):
self,
txn: LoggingTransaction,
room_id: str,
receipt_type: str,
user_id: str,
event_id: str,
data: JsonDict,
stream_id: int,
) -> Optional[int]:
"""Inserts a read-receipt into the database if it's newer than the current RR

Returns: int|None
Returns:
None if the RR is older than the current RR
otherwise, the rx timestamp of the event that the RR corresponds to
(or 0 if the event is unknown)
Expand Down Expand Up @@ -550,7 +574,7 @@ def insert_linearized_receipt_txn(
lock=False,
)

if receipt_type == "m.read" and stream_ordering is not None:
if receipt_type == ReceiptTypes.READ and stream_ordering is not None:
self._remove_old_push_actions_before_txn(
txn, room_id=room_id, user_id=user_id, stream_ordering=stream_ordering
)
Expand Down Expand Up @@ -580,7 +604,7 @@ async def insert_receipt(
else:
# we need to points in graph -> linearized form.
# TODO: Make this better.
def graph_to_linear(txn):
def graph_to_linear(txn: LoggingTransaction) -> str:
clause, args = make_in_list_sql_clause(
self.database_engine, "event_id", event_ids
)
Expand Down Expand Up @@ -634,11 +658,16 @@ def graph_to_linear(txn):
return stream_id, max_persisted_id

async def insert_graph_receipt(
self, room_id, receipt_type, user_id, event_ids, data
):
self,
room_id: str,
receipt_type: str,
user_id: str,
event_ids: List[str],
data: JsonDict,
) -> None:
assert self._can_write_to_receipts

return await self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"insert_graph_receipt",
self.insert_graph_receipt_txn,
room_id,
Expand All @@ -649,8 +678,14 @@ async def insert_graph_receipt(
)

def insert_graph_receipt_txn(
self, txn, room_id, receipt_type, user_id, event_ids, data
):
self,
txn: LoggingTransaction,
room_id: str,
receipt_type: str,
user_id: str,
event_ids: List[str],
data: JsonDict,
) -> None:
assert self._can_write_to_receipts

txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type))
Expand Down