Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add types to StreamToken and RoomStreamToken #8279

Merged
merged 4 commits into from
Sep 8, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/8279.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add type hints to `StreamToken` and `RoomStreamToken` classes.
5 changes: 2 additions & 3 deletions synapse/handlers/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -1310,12 +1310,11 @@ async def _generate_sync_entry_for_presence(
presence_source = self.event_sources.sources["presence"]

since_token = sync_result_builder.since_token
presence_key = None
include_offline = False
if since_token and not sync_result_builder.full_state:
presence_key = since_token.presence_key
include_offline = True
else:
presence_key = None
include_offline = False

presence, presence_key = await presence_source.get_new_events(
user=user,
Expand Down
7 changes: 3 additions & 4 deletions synapse/storage/databases/main/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,7 @@ async def get_cached_devices_for_user(self, user_id: str) -> Dict[str, JsonDict]
}

async def get_users_whose_devices_changed(
self, from_key: str, user_ids: Iterable[str]
self, from_key: int, user_ids: Iterable[str]
) -> Set[str]:
"""Get set of users whose devices have changed since `from_key` that
are in the given list of user_ids.
Expand All @@ -499,7 +499,6 @@ async def get_users_whose_devices_changed(
Returns:
The set of user_ids whose devices have changed since `from_key`
"""
from_key = int(from_key)

# Get set of users who *may* have changed. Users not in the returned
# list have definitely not changed.
Expand Down Expand Up @@ -533,7 +532,7 @@ def _get_users_whose_devices_changed_txn(txn):
)

async def get_users_whose_signatures_changed(
self, user_id: str, from_key: str
self, user_id: str, from_key: int
) -> Set[str]:
"""Get the users who have new cross-signing signatures made by `user_id` since
`from_key`.
Expand All @@ -545,7 +544,7 @@ async def get_users_whose_signatures_changed(
Returns:
A set of user IDs with updated signatures.
"""
from_key = int(from_key)

if self._user_signature_stream_cache.has_entity_changed(user_id, from_key):
sql = """
SELECT DISTINCT user_ids FROM user_signature_stream
Expand Down
21 changes: 11 additions & 10 deletions synapse/storage/databases/main/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@
def generate_pagination_where_clause(
direction: str,
column_names: Tuple[str, str],
from_token: Optional[Tuple[int, int]],
to_token: Optional[Tuple[int, int]],
from_token: Optional[Tuple[Optional[int], int]],
to_token: Optional[Tuple[Optional[int], int]],
engine: BaseDatabaseEngine,
) -> str:
"""Creates an SQL expression to bound the columns by the pagination
Expand Down Expand Up @@ -535,13 +535,13 @@ async def get_recent_event_ids_for_room(
if limit == 0:
return [], end_token

end_token = RoomStreamToken.parse(end_token)
parsed_end_token = RoomStreamToken.parse(end_token)

rows, token = await self.db_pool.runInteraction(
"get_recent_event_ids_for_room",
self._paginate_room_events_txn,
room_id,
from_token=end_token,
from_token=parsed_end_token,
limit=limit,
)

Expand Down Expand Up @@ -989,8 +989,8 @@ def _paginate_room_events_txn(
bounds = generate_pagination_where_clause(
direction=direction,
column_names=("topological_ordering", "stream_ordering"),
from_token=from_token,
to_token=to_token,
from_token=from_token.as_tuple(),
to_token=to_token.as_tuple() if to_token else None,
engine=self.database_engine,
)

Expand Down Expand Up @@ -1083,16 +1083,17 @@ async def paginate_room_events(
and `to_key`).
"""

from_key = RoomStreamToken.parse(from_key)
parsed_from_key = RoomStreamToken.parse(from_key)
parsed_to_key = None
if to_key:
to_key = RoomStreamToken.parse(to_key)
parsed_to_key = RoomStreamToken.parse(to_key)

rows, token = await self.db_pool.runInteraction(
"paginate_room_events",
self._paginate_room_events_txn,
room_id,
from_key,
to_key,
parsed_from_key,
parsed_to_key,
direction,
limit,
event_filter,
Expand Down
152 changes: 78 additions & 74 deletions synapse/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import string
import sys
from collections import namedtuple
from typing import Any, Dict, Mapping, MutableMapping, Tuple, Type, TypeVar
from typing import Any, Dict, Mapping, MutableMapping, Optional, Tuple, Type, TypeVar

import attr
from signedjson.key import decode_verify_key_bytes
Expand Down Expand Up @@ -362,38 +362,95 @@ def f2(m):
return username.decode("ascii")


class StreamToken(
namedtuple(
"Token",
(
"room_key",
"presence_key",
"typing_key",
"receipt_key",
"account_data_key",
"push_rules_key",
"to_device_key",
"device_list_key",
"groups_key",
),
@attr.s(frozen=True, slots=True)
class RoomStreamToken:
"""Tokens are positions between events. The token "s1" comes after event 1.

s0 s1
| |
[0] V [1] V [2]

Tokens can either be a point in the live event stream or a cursor going
through historic events.

When traversing the live event stream events are ordered by when they
arrived at the homeserver.

When traversing historic events the events are ordered by their depth in
the event graph "topological_ordering" and then by when they arrived at the
homeserver "stream_ordering".

Live tokens start with an "s" followed by the "stream_ordering" id of the
event it comes after. Historic tokens start with a "t" followed by the
"topological_ordering" id of the event it comes after, followed by "-",
followed by the "stream_ordering" id of the event it comes after.
"""

topological = attr.ib(
type=Optional[int],
validator=attr.validators.optional(attr.validators.instance_of(int)),
)
):
stream = attr.ib(type=int, validator=attr.validators.instance_of(int))

@classmethod
def parse(cls, string: str) -> "RoomStreamToken":
try:
if string[0] == "s":
return cls(topological=None, stream=int(string[1:]))
if string[0] == "t":
parts = string[1:].split("-", 1)
return cls(topological=int(parts[0]), stream=int(parts[1]))
except Exception:
pass
raise SynapseError(400, "Invalid token %r" % (string,))

@classmethod
def parse_stream_token(cls, string: str) -> "RoomStreamToken":
try:
if string[0] == "s":
return cls(topological=None, stream=int(string[1:]))
except Exception:
pass
raise SynapseError(400, "Invalid token %r" % (string,))

def as_tuple(self) -> Tuple[Optional[int], int]:
return (self.topological, self.stream)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Curious why you used attr.astuple below (in StreamToken.to_string) and not here.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Only because I think astuple is a bit magical so I would like to avoid it for this, but StreamToken already acts a bit like a tuple in a number of places and so may as well just attr.astuple rather than writing it out. 🤷


def __str__(self) -> str:
if self.topological is not None:
return "t%d-%d" % (self.topological, self.stream)
else:
return "s%d" % (self.stream,)


@attr.s(slots=True, frozen=True)
class StreamToken:
room_key = attr.ib(type=str)
presence_key = attr.ib(type=int)
typing_key = attr.ib(type=int)
receipt_key = attr.ib(type=int)
account_data_key = attr.ib(type=int)
push_rules_key = attr.ib(type=int)
to_device_key = attr.ib(type=int)
device_list_key = attr.ib(type=int)
groups_key = attr.ib(type=int)

_SEPARATOR = "_"
START = None # type: StreamToken

@classmethod
def from_string(cls, string):
try:
keys = string.split(cls._SEPARATOR)
while len(keys) < len(cls._fields):
while len(keys) < len(attr.fields(cls)):
# i.e. old token from before receipt_key
keys.append("0")
return cls(*keys)
return cls(keys[0], *(int(k) for k in keys[1:]))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this what allows the removal of the type casts in e.g. get_users_whose_devices_changed?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah. Yes. I seem to have squashed away the commit that changes from sometimes storing strings to always storing ints

except Exception:
raise SynapseError(400, "Invalid Token")

def to_string(self):
return self._SEPARATOR.join([str(k) for k in self])
return self._SEPARATOR.join([str(k) for k in attr.astuple(self)])

@property
def room_stream_id(self):
Expand Down Expand Up @@ -435,63 +492,10 @@ def copy_and_advance(self, key, new_value):
return self

def copy_and_replace(self, key, new_value):
return self._replace(**{key: new_value})


StreamToken.START = StreamToken(*(["s0"] + ["0"] * (len(StreamToken._fields) - 1)))


class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
"""Tokens are positions between events. The token "s1" comes after event 1.

s0 s1
| |
[0] V [1] V [2]

Tokens can either be a point in the live event stream or a cursor going
through historic events.

When traversing the live event stream events are ordered by when they
arrived at the homeserver.

When traversing historic events the events are ordered by their depth in
the event graph "topological_ordering" and then by when they arrived at the
homeserver "stream_ordering".

Live tokens start with an "s" followed by the "stream_ordering" id of the
event it comes after. Historic tokens start with a "t" followed by the
"topological_ordering" id of the event it comes after, followed by "-",
followed by the "stream_ordering" id of the event it comes after.
"""
return attr.evolve(self, **{key: new_value})

__slots__ = [] # type: list

@classmethod
def parse(cls, string):
try:
if string[0] == "s":
return cls(topological=None, stream=int(string[1:]))
if string[0] == "t":
parts = string[1:].split("-", 1)
return cls(topological=int(parts[0]), stream=int(parts[1]))
except Exception:
pass
raise SynapseError(400, "Invalid token %r" % (string,))

@classmethod
def parse_stream_token(cls, string):
try:
if string[0] == "s":
return cls(topological=None, stream=int(string[1:]))
except Exception:
pass
raise SynapseError(400, "Invalid token %r" % (string,))

def __str__(self):
if self.topological is not None:
return "t%d-%d" % (self.topological, self.stream)
else:
return "s%d" % (self.stream,)
StreamToken.START = StreamToken.from_string("s0_0")


class ThirdPartyInstanceID(
Expand Down