Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Convert more cached return values to immutable types #16356

Merged
merged 5 commits into from
Sep 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/16356.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve type hints.
8 changes: 4 additions & 4 deletions synapse/api/filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from synapse.api.errors import SynapseError
from synapse.api.presence import UserPresenceState
from synapse.events import EventBase, relation_from_event
from synapse.types import JsonDict, RoomID, UserID
from synapse.types import JsonDict, JsonMapping, RoomID, UserID

if TYPE_CHECKING:
from synapse.server import HomeServer
Expand Down Expand Up @@ -191,7 +191,7 @@ def check_valid_filter(self, user_filter_json: JsonDict) -> None:


class FilterCollection:
def __init__(self, hs: "HomeServer", filter_json: JsonDict):
def __init__(self, hs: "HomeServer", filter_json: JsonMapping):
self._filter_json = filter_json

room_filter_json = self._filter_json.get("room", {})
Expand Down Expand Up @@ -219,7 +219,7 @@ def __init__(self, hs: "HomeServer", filter_json: JsonDict):
def __repr__(self) -> str:
return "<FilterCollection %s>" % (json.dumps(self._filter_json),)

def get_filter_json(self) -> JsonDict:
def get_filter_json(self) -> JsonMapping:
return self._filter_json

def timeline_limit(self) -> int:
Expand Down Expand Up @@ -313,7 +313,7 @@ def blocks_all_room_timeline(self) -> bool:


class Filter:
def __init__(self, hs: "HomeServer", filter_json: JsonDict):
def __init__(self, hs: "HomeServer", filter_json: JsonMapping):
self._hs = hs
self._store = hs.get_datastores().main
self.filter_json = filter_json
Expand Down
4 changes: 2 additions & 2 deletions synapse/federation/federation_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
from synapse.http.client import is_unknown_endpoint
from synapse.http.types import QueryParams
from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, tag_args, trace
from synapse.types import JsonDict, UserID, get_domain_from_id
from synapse.types import JsonDict, StrCollection, UserID, get_domain_from_id
from synapse.util.async_helpers import concurrently_execute
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.retryutils import NotRetryingDestination
Expand Down Expand Up @@ -1704,7 +1704,7 @@ async def send_request(
async def timestamp_to_event(
self,
*,
destinations: List[str],
destinations: StrCollection,
room_id: str,
timestamp: int,
direction: Direction,
Expand Down
2 changes: 1 addition & 1 deletion synapse/handlers/federation_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -1538,7 +1538,7 @@ async def _resync_device(self, sender: str) -> None:
logger.exception("Failed to resync device for %s", sender)

async def backfill_event_id(
self, destinations: List[str], room_id: str, event_id: str
self, destinations: StrCollection, room_id: str, event_id: str
) -> PulledPduInfo:
"""Backfill a single event and persist it as a non-outlier which means
we also pull in all of the state and auth events necessary for it.
Expand Down
14 changes: 12 additions & 2 deletions synapse/handlers/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,17 @@
# limitations under the License.
import enum
import logging
from typing import TYPE_CHECKING, Collection, Dict, FrozenSet, Iterable, List, Optional
from typing import (
TYPE_CHECKING,
Collection,
Dict,
FrozenSet,
Iterable,
List,
Mapping,
Optional,
Sequence,
)

import attr

Expand Down Expand Up @@ -245,7 +255,7 @@ async def redact_events_related_to(

async def get_references_for_events(
self, event_ids: Collection[str], ignored_users: FrozenSet[str] = frozenset()
) -> Dict[str, List[_RelatedEvent]]:
) -> Mapping[str, Sequence[_RelatedEvent]]:
"""Get a list of references to the given events.

Args:
Expand Down
4 changes: 2 additions & 2 deletions synapse/rest/client/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.http.site import SynapseRequest
from synapse.types import JsonDict, UserID
from synapse.types import JsonDict, JsonMapping, UserID

from ._base import client_patterns, set_timeline_upper_limit

Expand All @@ -41,7 +41,7 @@ def __init__(self, hs: "HomeServer"):

async def on_GET(
self, request: SynapseRequest, user_id: str, filter_id: str
) -> Tuple[int, JsonDict]:
) -> Tuple[int, JsonMapping]:
target_user = UserID.from_string(user_id)
requester = await self.auth.get_user_by_req(request)

Expand Down
2 changes: 1 addition & 1 deletion synapse/storage/controllers/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,7 +582,7 @@ async def get_current_hosts_in_room(self, room_id: str) -> AbstractSet[str]:

@trace
@tag_args
async def get_current_hosts_in_room_ordered(self, room_id: str) -> List[str]:
async def get_current_hosts_in_room_ordered(self, room_id: str) -> Tuple[str, ...]:
"""Get current hosts in room based on current state.

Blocks until we have full state for the given room. This only happens for rooms
Expand Down
4 changes: 2 additions & 2 deletions synapse/storage/databases/main/filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
LoggingTransaction,
)
from synapse.storage.engines import PostgresEngine
from synapse.types import JsonDict, UserID
from synapse.types import JsonDict, JsonMapping, UserID
from synapse.util.caches.descriptors import cached

if TYPE_CHECKING:
Expand Down Expand Up @@ -145,7 +145,7 @@ def _final_batch(txn: LoggingTransaction, lower_bound_id: str) -> None:
@cached(num_args=2)
async def get_user_filter(
self, user_id: UserID, filter_id: Union[int, str]
) -> JsonDict:
) -> JsonMapping:
# filter_id is BIGINT UNSIGNED, so if it isn't a number, fail
# with a coherent error message rather than 500 M_UNKNOWN.
try:
Expand Down
4 changes: 2 additions & 2 deletions synapse/storage/databases/main/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,7 @@ async def get_references_for_event(self, event_id: str) -> List[JsonDict]:
@cachedList(cached_method_name="get_references_for_event", list_name="event_ids")
async def get_references_for_events(
self, event_ids: Collection[str]
) -> Mapping[str, Optional[List[_RelatedEvent]]]:
) -> Mapping[str, Optional[Sequence[_RelatedEvent]]]:
"""Get a list of references to the given events.

Args:
Expand Down Expand Up @@ -931,7 +931,7 @@ async def get_threads(
room_id: str,
limit: int = 5,
from_token: Optional[ThreadsNextBatch] = None,
) -> Tuple[List[str], Optional[ThreadsNextBatch]]:
) -> Tuple[Sequence[str], Optional[ThreadsNextBatch]]:
"""Get a list of thread IDs, ordered by topological ordering of their
latest reply.

Expand Down
10 changes: 6 additions & 4 deletions synapse/storage/databases/main/roommember.py
Original file line number Diff line number Diff line change
Expand Up @@ -984,7 +984,7 @@ def get_current_hosts_in_room_txn(txn: LoggingTransaction) -> Set[str]:
)

@cached(iterable=True, max_entries=10000)
async def get_current_hosts_in_room_ordered(self, room_id: str) -> List[str]:
async def get_current_hosts_in_room_ordered(self, room_id: str) -> Tuple[str, ...]:
"""
Get current hosts in room based on current state.

Expand Down Expand Up @@ -1013,12 +1013,14 @@ async def get_current_hosts_in_room_ordered(self, room_id: str) -> List[str]:
# `get_users_in_room` rather than funky SQL.

domains = await self.get_current_hosts_in_room(room_id)
return list(domains)
return tuple(domains)
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved

# For PostgreSQL we can use a regex to pull out the domains from the
# joined users in `current_state_events` via regex.

def get_current_hosts_in_room_ordered_txn(txn: LoggingTransaction) -> List[str]:
def get_current_hosts_in_room_ordered_txn(
txn: LoggingTransaction,
) -> Tuple[str, ...]:
# Returns a list of servers currently joined in the room sorted by
# longest in the room first (aka. with the lowest depth). The
# heuristic of sorting by servers who have been in the room the
Expand All @@ -1043,7 +1045,7 @@ def get_current_hosts_in_room_ordered_txn(txn: LoggingTransaction) -> List[str]:
"""
txn.execute(sql, (room_id,))
# `server_domain` will be `NULL` for malformed MXIDs with no colons.
return [d for d, in txn if d is not None]
return tuple(d for d, in txn if d is not None)

return await self.db_pool.runInteraction(
"get_current_hosts_in_room_ordered", get_current_hosts_in_room_ordered_txn
Expand Down
35 changes: 19 additions & 16 deletions tests/util/caches/test_descriptors.py
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My head always melts when I look at this file.

Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@
import logging
from typing import (
Any,
Dict,
Generator,
Iterable,
List,
Mapping,
NoReturn,
Optional,
Set,
Expand Down Expand Up @@ -96,7 +96,7 @@ def __init__(self) -> None:
self.mock = mock.Mock()

@descriptors.cached(num_args=1)
def fn(self, arg1: int, arg2: int) -> mock.Mock:
def fn(self, arg1: int, arg2: int) -> str:
return self.mock(arg1, arg2)

obj = Cls()
Expand Down Expand Up @@ -228,8 +228,9 @@ class Cls:
call_count = 0

@cached()
def fn(self, arg1: int) -> Optional[Deferred]:
def fn(self, arg1: int) -> Deferred:
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
self.call_count += 1
assert self.result is not None
return self.result

obj = Cls()
Expand Down Expand Up @@ -401,31 +402,31 @@ def __init__(self) -> None:
self.mock = mock.Mock()

@descriptors.cached(iterable=True)
def fn(self, arg1: int, arg2: int) -> List[str]:
def fn(self, arg1: int, arg2: int) -> Tuple[str, ...]:
return self.mock(arg1, arg2)

obj = Cls()

obj.mock.return_value = ["spam", "eggs"]
obj.mock.return_value = ("spam", "eggs")
r = obj.fn(1, 2)
self.assertEqual(r.result, ["spam", "eggs"])
self.assertEqual(r.result, ("spam", "eggs"))
obj.mock.assert_called_once_with(1, 2)
obj.mock.reset_mock()

# a call with different params should call the mock again
obj.mock.return_value = ["chips"]
obj.mock.return_value = ("chips",)
r = obj.fn(1, 3)
self.assertEqual(r.result, ["chips"])
self.assertEqual(r.result, ("chips",))
obj.mock.assert_called_once_with(1, 3)
obj.mock.reset_mock()

# the two values should now be cached
self.assertEqual(len(obj.fn.cache.cache), 3)

r = obj.fn(1, 2)
self.assertEqual(r.result, ["spam", "eggs"])
self.assertEqual(r.result, ("spam", "eggs"))
r = obj.fn(1, 3)
self.assertEqual(r.result, ["chips"])
self.assertEqual(r.result, ("chips",))
obj.mock.assert_not_called()

def test_cache_iterable_with_sync_exception(self) -> None:
Expand Down Expand Up @@ -784,7 +785,9 @@ def fn(self, arg1: int, arg2: int) -> None:
pass

@descriptors.cachedList(cached_method_name="fn", list_name="args1")
async def list_fn(self, args1: Iterable[int], arg2: int) -> Dict[int, str]:
async def list_fn(
self, args1: Iterable[int], arg2: int
) -> Mapping[int, str]:
context = current_context()
assert isinstance(context, LoggingContext)
assert context.name == "c1"
Expand Down Expand Up @@ -847,11 +850,11 @@ def fn(self, arg1: int) -> None:
pass

@descriptors.cachedList(cached_method_name="fn", list_name="args1")
def list_fn(self, args1: List[int]) -> "Deferred[dict]":
def list_fn(self, args1: List[int]) -> "Deferred[Mapping[int, str]]":
return self.mock(args1)

obj = Cls()
deferred_result: "Deferred[dict]" = Deferred()
deferred_result: "Deferred[Mapping[int, str]]" = Deferred()
obj.mock.return_value = deferred_result

# start off several concurrent lookups of the same key
Expand Down Expand Up @@ -890,7 +893,7 @@ def fn(self, arg1: int, arg2: int) -> None:
pass

@descriptors.cachedList(cached_method_name="fn", list_name="args1")
async def list_fn(self, args1: List[int], arg2: int) -> Dict[int, str]:
async def list_fn(self, args1: List[int], arg2: int) -> Mapping[int, str]:
# we want this to behave like an asynchronous function
await run_on_reactor()
return self.mock(args1, arg2)
Expand Down Expand Up @@ -929,7 +932,7 @@ def fn(self, arg1: int) -> None:
pass

@cachedList(cached_method_name="fn", list_name="args")
async def list_fn(self, args: List[int]) -> Dict[int, str]:
async def list_fn(self, args: List[int]) -> Mapping[int, str]:
await complete_lookup
return {arg: str(arg) for arg in args}

Expand Down Expand Up @@ -964,7 +967,7 @@ def fn(self, arg1: int) -> None:
pass

@cachedList(cached_method_name="fn", list_name="args")
async def list_fn(self, args: List[int]) -> Dict[int, str]:
async def list_fn(self, args: List[int]) -> Mapping[int, str]:
await make_deferred_yieldable(complete_lookup)
self.inner_context_was_finished = current_context().finished
return {arg: str(arg) for arg in args}
Expand Down