Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Add a constant for receipt types (m.read). (#11531)
Browse files Browse the repository at this point in the history
And expand some type hints in the receipts storage module.
  • Loading branch information
clokep authored Dec 8, 2021
1 parent 7ecaa3b commit d93362d
Show file tree
Hide file tree
Showing 9 changed files with 87 additions and 45 deletions.
1 change: 1 addition & 0 deletions changelog.d/11531.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add a receipt types constant for `m.read`.
4 changes: 4 additions & 0 deletions synapse/api/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,5 +253,9 @@ class GuestAccess:
FORBIDDEN: Final = "forbidden"


class ReceiptTypes:
READ: Final = "m.read"


class ReadReceiptEventFields:
MSC2285_HIDDEN: Final = "org.matrix.msc2285.hidden"
6 changes: 3 additions & 3 deletions synapse/handlers/receipts.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import logging
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple

from synapse.api.constants import ReadReceiptEventFields
from synapse.api.constants import ReadReceiptEventFields, ReceiptTypes
from synapse.appservice import ApplicationService
from synapse.streams import EventSource
from synapse.types import JsonDict, ReadReceipt, UserID, get_domain_from_id
Expand Down Expand Up @@ -178,7 +178,7 @@ def filter_out_hidden(events: List[JsonDict], user_id: str) -> List[JsonDict]:

for event_id in content.keys():
event_content = content.get(event_id, {})
m_read = event_content.get("m.read", {})
m_read = event_content.get(ReceiptTypes.READ, {})

# If m_read is missing copy over the original event_content as there is nothing to process here
if not m_read:
Expand Down Expand Up @@ -206,7 +206,7 @@ def filter_out_hidden(events: List[JsonDict], user_id: str) -> List[JsonDict]:

# Set new users unless empty
if len(new_users.keys()) > 0:
new_event["content"][event_id] = {"m.read": new_users}
new_event["content"][event_id] = {ReceiptTypes.READ: new_users}

# Append new_event to visible_events unless empty
if len(new_event["content"].keys()) > 0:
Expand Down
4 changes: 2 additions & 2 deletions synapse/handlers/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import attr
from prometheus_client import Counter

from synapse.api.constants import AccountDataTypes, EventTypes, Membership
from synapse.api.constants import AccountDataTypes, EventTypes, Membership, ReceiptTypes
from synapse.api.filtering import FilterCollection
from synapse.api.presence import UserPresenceState
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
Expand Down Expand Up @@ -1046,7 +1046,7 @@ async def unread_notifs_for_room_id(
last_unread_event_id = await self.store.get_last_receipt_event_id_for_user(
user_id=sync_config.user.to_string(),
room_id=room_id,
receipt_type="m.read",
receipt_type=ReceiptTypes.READ,
)

notifs = await self.store.get_unread_event_push_actions_by_room_for_user(
Expand Down
3 changes: 2 additions & 1 deletion synapse/push/push_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
from typing import Dict

from synapse.api.constants import ReceiptTypes
from synapse.events import EventBase
from synapse.push.presentable_names import calculate_room_name, name_from_member_event
from synapse.storage import Storage
Expand All @@ -23,7 +24,7 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -
invites = await store.get_invited_rooms_for_local_user(user_id)
joins = await store.get_rooms_for_user(user_id)

my_receipts_by_room = await store.get_receipts_for_user(user_id, "m.read")
my_receipts_by_room = await store.get_receipts_for_user(user_id, ReceiptTypes.READ)

badge = len(invites)

Expand Down
3 changes: 2 additions & 1 deletion synapse/rest/client/notifications.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import logging
from typing import TYPE_CHECKING, Tuple

from synapse.api.constants import ReceiptTypes
from synapse.events.utils import format_event_for_client_v2_without_room_id
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_integer, parse_string
Expand Down Expand Up @@ -54,7 +55,7 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
)

receipts_by_room = await self.store.get_receipts_for_user_with_orderings(
user_id, "m.read"
user_id, ReceiptTypes.READ
)

notif_event_ids = [pa["event_id"] for pa in push_actions]
Expand Down
6 changes: 3 additions & 3 deletions synapse/rest/client/read_marker.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import logging
from typing import TYPE_CHECKING, Tuple

from synapse.api.constants import ReadReceiptEventFields
from synapse.api.constants import ReadReceiptEventFields, ReceiptTypes
from synapse.api.errors import Codes, SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
Expand Down Expand Up @@ -48,7 +48,7 @@ async def on_POST(
await self.presence_handler.bump_presence_active_time(requester.user)

body = parse_json_object_from_request(request)
read_event_id = body.get("m.read", None)
read_event_id = body.get(ReceiptTypes.READ, None)
hidden = body.get(ReadReceiptEventFields.MSC2285_HIDDEN, False)

if not isinstance(hidden, bool):
Expand All @@ -62,7 +62,7 @@ async def on_POST(
if read_event_id:
await self.receipts_handler.received_client_receipt(
room_id,
"m.read",
ReceiptTypes.READ,
user_id=requester.user.to_string(),
event_id=read_event_id,
hidden=hidden,
Expand Down
4 changes: 2 additions & 2 deletions synapse/rest/client/receipts.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import re
from typing import TYPE_CHECKING, Tuple

from synapse.api.constants import ReadReceiptEventFields
from synapse.api.constants import ReadReceiptEventFields, ReceiptTypes
from synapse.api.errors import Codes, SynapseError
from synapse.http import get_request_user_agent
from synapse.http.server import HttpServer
Expand Down Expand Up @@ -53,7 +53,7 @@ async def on_POST(
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)

if receipt_type != "m.read":
if receipt_type != ReceiptTypes.READ:
raise SynapseError(400, "Receipt type must be 'm.read'")

# Do not allow older SchildiChat and Element Android clients (prior to Element/1.[012].x) to send an empty body.
Expand Down
101 changes: 68 additions & 33 deletions synapse/storage/databases/main/receipts.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,25 @@
# limitations under the License.

import logging
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple
from typing import (
TYPE_CHECKING,
Any,
Collection,
Dict,
Iterable,
List,
Optional,
Set,
Tuple,
)

from twisted.internet import defer

from synapse.api.constants import ReceiptTypes
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import ReceiptsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.types import JsonDict
Expand Down Expand Up @@ -78,17 +89,13 @@ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
"ReceiptsRoomChangeCache", self.get_max_receipt_stream_id()
)

def get_max_receipt_stream_id(self):
"""Get the current max stream ID for receipts stream
Returns:
int
"""
def get_max_receipt_stream_id(self) -> int:
"""Get the current max stream ID for receipts stream"""
return self._receipts_id_gen.get_current_token()

@cached()
async def get_users_with_read_receipts_in_room(self, room_id):
receipts = await self.get_receipts_for_room(room_id, "m.read")
async def get_users_with_read_receipts_in_room(self, room_id: str) -> Set[str]:
receipts = await self.get_receipts_for_room(room_id, ReceiptTypes.READ)
return {r["user_id"] for r in receipts}

@cached(num_args=2)
Expand Down Expand Up @@ -119,7 +126,9 @@ async def get_last_receipt_event_id_for_user(
)

@cached(num_args=2)
async def get_receipts_for_user(self, user_id, receipt_type):
async def get_receipts_for_user(
self, user_id: str, receipt_type: str
) -> Dict[str, str]:
rows = await self.db_pool.simple_select_list(
table="receipts_linearized",
keyvalues={"user_id": user_id, "receipt_type": receipt_type},
Expand All @@ -129,8 +138,10 @@ async def get_receipts_for_user(self, user_id, receipt_type):

return {row["room_id"]: row["event_id"] for row in rows}

async def get_receipts_for_user_with_orderings(self, user_id, receipt_type):
def f(txn):
async def get_receipts_for_user_with_orderings(
self, user_id: str, receipt_type: str
) -> JsonDict:
def f(txn: LoggingTransaction) -> List[Tuple[str, str, int, int]]:
sql = (
"SELECT rl.room_id, rl.event_id,"
" e.topological_ordering, e.stream_ordering"
Expand Down Expand Up @@ -209,10 +220,10 @@ async def get_linearized_receipts_for_room(
@cached(num_args=3, tree=True)
async def _get_linearized_receipts_for_room(
self, room_id: str, to_key: int, from_key: Optional[int] = None
) -> List[dict]:
) -> List[JsonDict]:
"""See get_linearized_receipts_for_room"""

def f(txn):
def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
if from_key:
sql = (
"SELECT * FROM receipts_linearized WHERE"
Expand Down Expand Up @@ -250,11 +261,13 @@ def f(txn):
list_name="room_ids",
num_args=3,
)
async def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
async def _get_linearized_receipts_for_rooms(
self, room_ids: Collection[str], to_key: int, from_key: Optional[int] = None
) -> Dict[str, List[JsonDict]]:
if not room_ids:
return {}

def f(txn):
def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
if from_key:
sql = """
SELECT * FROM receipts_linearized WHERE
Expand Down Expand Up @@ -323,7 +336,7 @@ async def get_linearized_receipts_for_all_rooms(
A dictionary of roomids to a list of receipts.
"""

def f(txn):
def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
if from_key:
sql = """
SELECT * FROM receipts_linearized WHERE
Expand Down Expand Up @@ -379,7 +392,7 @@ async def get_users_sent_receipts_between(
if last_id == current_id:
return defer.succeed([])

def _get_users_sent_receipts_between_txn(txn):
def _get_users_sent_receipts_between_txn(txn: LoggingTransaction) -> List[str]:
sql = """
SELECT DISTINCT user_id FROM receipts_linearized
WHERE ? < stream_id AND stream_id <= ?
Expand Down Expand Up @@ -419,7 +432,9 @@ async def get_all_updated_receipts(
if last_id == current_id:
return [], current_id, False

def get_all_updated_receipts_txn(txn):
def get_all_updated_receipts_txn(
txn: LoggingTransaction,
) -> Tuple[List[Tuple[int, list]], int, bool]:
sql = """
SELECT stream_id, room_id, receipt_type, user_id, event_id, data
FROM receipts_linearized
Expand All @@ -446,8 +461,8 @@ def get_all_updated_receipts_txn(txn):

def _invalidate_get_users_with_receipts_in_room(
self, room_id: str, receipt_type: str, user_id: str
):
if receipt_type != "m.read":
) -> None:
if receipt_type != ReceiptTypes.READ:
return

res = self.get_users_with_read_receipts_in_room.cache.get_immediate(
Expand All @@ -461,7 +476,9 @@ def _invalidate_get_users_with_receipts_in_room(

self.get_users_with_read_receipts_in_room.invalidate((room_id,))

def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
def invalidate_caches_for_receipt(
self, room_id: str, receipt_type: str, user_id: str
) -> None:
self.get_receipts_for_user.invalidate((user_id, receipt_type))
self._get_linearized_receipts_for_room.invalidate((room_id,))
self.get_last_receipt_event_id_for_user.invalidate(
Expand All @@ -482,11 +499,18 @@ def process_replication_rows(self, stream_name, instance_name, token, rows):
return super().process_replication_rows(stream_name, instance_name, token, rows)

def insert_linearized_receipt_txn(
self, txn, room_id, receipt_type, user_id, event_id, data, stream_id
):
self,
txn: LoggingTransaction,
room_id: str,
receipt_type: str,
user_id: str,
event_id: str,
data: JsonDict,
stream_id: int,
) -> Optional[int]:
"""Inserts a read-receipt into the database if it's newer than the current RR
Returns: int|None
Returns:
None if the RR is older than the current RR
otherwise, the rx timestamp of the event that the RR corresponds to
(or 0 if the event is unknown)
Expand Down Expand Up @@ -550,7 +574,7 @@ def insert_linearized_receipt_txn(
lock=False,
)

if receipt_type == "m.read" and stream_ordering is not None:
if receipt_type == ReceiptTypes.READ and stream_ordering is not None:
self._remove_old_push_actions_before_txn(
txn, room_id=room_id, user_id=user_id, stream_ordering=stream_ordering
)
Expand Down Expand Up @@ -580,7 +604,7 @@ async def insert_receipt(
else:
# we need to points in graph -> linearized form.
# TODO: Make this better.
def graph_to_linear(txn):
def graph_to_linear(txn: LoggingTransaction) -> str:
clause, args = make_in_list_sql_clause(
self.database_engine, "event_id", event_ids
)
Expand Down Expand Up @@ -634,11 +658,16 @@ def graph_to_linear(txn):
return stream_id, max_persisted_id

async def insert_graph_receipt(
self, room_id, receipt_type, user_id, event_ids, data
):
self,
room_id: str,
receipt_type: str,
user_id: str,
event_ids: List[str],
data: JsonDict,
) -> None:
assert self._can_write_to_receipts

return await self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"insert_graph_receipt",
self.insert_graph_receipt_txn,
room_id,
Expand All @@ -649,8 +678,14 @@ async def insert_graph_receipt(
)

def insert_graph_receipt_txn(
self, txn, room_id, receipt_type, user_id, event_ids, data
):
self,
txn: LoggingTransaction,
room_id: str,
receipt_type: str,
user_id: str,
event_ids: List[str],
data: JsonDict,
) -> None:
assert self._can_write_to_receipts

txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type))
Expand Down

0 comments on commit d93362d

Please sign in to comment.