Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Improve type hints for cached decorator #15658

Merged
merged 6 commits into from
May 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/15658.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve type hints.
34 changes: 33 additions & 1 deletion scripts-dev/mypy_synapse_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,11 @@

from typing import Callable, Optional, Type

from mypy.erasetype import remove_instance_last_known_values
from mypy.nodes import ARG_NAMED_OPT
from mypy.plugin import MethodSigContext, Plugin
from mypy.typeops import bind_self
from mypy.types import CallableType, NoneType, UnionType
from mypy.types import CallableType, Instance, NoneType, UnionType


class SynapsePlugin(Plugin):
Expand Down Expand Up @@ -92,10 +93,41 @@ def cached_function_method_signature(ctx: MethodSigContext) -> CallableType:
arg_names.append("on_invalidate")
arg_kinds.append(ARG_NAMED_OPT) # Arg is an optional kwarg.

# Finally we ensure the return type is a Deferred.
if (
isinstance(signature.ret_type, Instance)
and signature.ret_type.type.fullname == "twisted.internet.defer.Deferred"
):
# If it is already a Deferred, nothing to do.
ret_type = signature.ret_type
else:
ret_arg = None
if isinstance(signature.ret_type, Instance):
# If a coroutine, wrap the coroutine's return type in a Deferred.
if signature.ret_type.type.fullname == "typing.Coroutine":
ret_arg = signature.ret_type.args[2]

# If an awaitable, wrap the awaitable's final value in a Deferred.
elif signature.ret_type.type.fullname == "typing.Awaitable":
ret_arg = signature.ret_type.args[0]

# Otherwise, wrap the return value in a Deferred.
if ret_arg is None:
ret_arg = signature.ret_type

# This should be able to use ctx.api.named_generic_type, but that doesn't seem
# to find the correct symbol for anything more than 1 module deep.
#
# modules is not part of CheckerPluginInterface. The following is a combination
# of TypeChecker.named_generic_type and TypeChecker.lookup_typeinfo.
sym = ctx.api.modules["twisted.internet.defer"].names.get("Deferred") # type: ignore[attr-defined]
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
ret_type = Instance(sym.node, [remove_instance_last_known_values(ret_arg)])

signature = signature.copy_modified(
arg_types=arg_types,
arg_names=arg_names,
arg_kinds=arg_kinds,
ret_type=ret_type,
)

return signature
Expand Down
2 changes: 1 addition & 1 deletion synapse/storage/databases/main/roommember.py
Original file line number Diff line number Diff line change
Expand Up @@ -1099,7 +1099,7 @@ async def _get_joined_hosts(
# `get_joined_hosts` is called with the "current" state group for the
# room, and so consecutive calls will be for consecutive state groups
# which point to the previous state group.
cache = await self._get_joined_hosts_cache(room_id) # type: ignore[misc]
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
cache = await self._get_joined_hosts_cache(room_id)

# If the state group in the cache matches, we already have the data we need.
if state_entry.state_group == cache.state_group:
Expand Down
6 changes: 4 additions & 2 deletions synapse/util/caches/descriptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,9 @@ def __init__(
self.iterable = iterable
self.prune_unread_entries = prune_unread_entries

def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]:
def __get__(
self, obj: Optional[Any], owner: Optional[Type]
) -> Callable[..., "defer.Deferred[Any]"]:
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
cache: DeferredCache[CacheKey, Any] = DeferredCache(
name=self.name,
max_entries=self.max_entries,
Expand All @@ -232,7 +234,7 @@ def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., An
get_cache_key = self.cache_key_builder

@functools.wraps(self.orig)
def _wrapped(*args: Any, **kwargs: Any) -> Any:
def _wrapped(*args: Any, **kwargs: Any) -> "defer.Deferred[Any]":
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
# If we're passed a cache_context then we'll want to call its invalidate()
# whenever we are invalidated
invalidate_callback = kwargs.pop("on_invalidate", None)
Expand Down
82 changes: 29 additions & 53 deletions tests/appservice/test_appservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from typing import Generator
from typing import Any, Generator
from unittest.mock import Mock

from twisted.internet import defer
Expand Down Expand Up @@ -49,93 +49,81 @@ def setUp(self) -> None:
@defer.inlineCallbacks
def test_regex_user_id_prefix_match(
self,
) -> Generator["defer.Deferred[object]", object, None]:
) -> Generator["defer.Deferred[Any]", object, None]:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Deferred[object] can't yield Deferred[bool], but using Deferred[bool] doesn't work with inlineCallbacks.

self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
self.event.sender = "@irc_foobar:matrix.org"
self.assertTrue(
(
yield defer.ensureDeferred(
self.service.is_interested_in_event(
self.event.event_id, self.event, self.store
)
yield self.service.is_interested_in_event(
self.event.event_id, self.event, self.store
)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is already a Deferred, so we don't need to wrap it in ensureDeferred here.

)
)

@defer.inlineCallbacks
def test_regex_user_id_prefix_no_match(
self,
) -> Generator["defer.Deferred[object]", object, None]:
) -> Generator["defer.Deferred[Any]", object, None]:
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
self.event.sender = "@someone_else:matrix.org"
self.assertFalse(
(
yield defer.ensureDeferred(
self.service.is_interested_in_event(
self.event.event_id, self.event, self.store
)
yield self.service.is_interested_in_event(
self.event.event_id, self.event, self.store
)
)
)

@defer.inlineCallbacks
def test_regex_room_member_is_checked(
self,
) -> Generator["defer.Deferred[object]", object, None]:
) -> Generator["defer.Deferred[Any]", object, None]:
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
self.event.sender = "@someone_else:matrix.org"
self.event.type = "m.room.member"
self.event.state_key = "@irc_foobar:matrix.org"
self.assertTrue(
(
yield defer.ensureDeferred(
self.service.is_interested_in_event(
self.event.event_id, self.event, self.store
)
yield self.service.is_interested_in_event(
self.event.event_id, self.event, self.store
)
)
)

@defer.inlineCallbacks
def test_regex_room_id_match(
self,
) -> Generator["defer.Deferred[object]", object, None]:
) -> Generator["defer.Deferred[Any]", object, None]:
self.service.namespaces[ApplicationService.NS_ROOMS].append(
_regex("!some_prefix.*some_suffix:matrix.org")
)
self.event.room_id = "!some_prefixs0m3th1nGsome_suffix:matrix.org"
self.assertTrue(
(
yield defer.ensureDeferred(
self.service.is_interested_in_event(
self.event.event_id, self.event, self.store
)
yield self.service.is_interested_in_event(
self.event.event_id, self.event, self.store
)
)
)

@defer.inlineCallbacks
def test_regex_room_id_no_match(
self,
) -> Generator["defer.Deferred[object]", object, None]:
) -> Generator["defer.Deferred[Any]", object, None]:
self.service.namespaces[ApplicationService.NS_ROOMS].append(
_regex("!some_prefix.*some_suffix:matrix.org")
)
self.event.room_id = "!XqBunHwQIXUiqCaoxq:matrix.org"
self.assertFalse(
(
yield defer.ensureDeferred(
self.service.is_interested_in_event(
self.event.event_id, self.event, self.store
)
yield self.service.is_interested_in_event(
self.event.event_id, self.event, self.store
)
)
)

@defer.inlineCallbacks
def test_regex_alias_match(
self,
) -> Generator["defer.Deferred[object]", object, None]:
def test_regex_alias_match(self) -> Generator["defer.Deferred[Any]", object, None]:
self.service.namespaces[ApplicationService.NS_ALIASES].append(
_regex("#irc_.*:matrix.org")
)
Expand All @@ -145,10 +133,8 @@ def test_regex_alias_match(
self.store.get_local_users_in_room = simple_async_mock([])
self.assertTrue(
(
yield defer.ensureDeferred(
self.service.is_interested_in_event(
self.event.event_id, self.event, self.store
)
yield self.service.is_interested_in_event(
self.event.event_id, self.event, self.store
)
)
)
Expand Down Expand Up @@ -192,7 +178,7 @@ def test_exclusive_room(self) -> None:
@defer.inlineCallbacks
def test_regex_alias_no_match(
self,
) -> Generator["defer.Deferred[object]", object, None]:
) -> Generator["defer.Deferred[Any]", object, None]:
self.service.namespaces[ApplicationService.NS_ALIASES].append(
_regex("#irc_.*:matrix.org")
)
Expand All @@ -213,7 +199,7 @@ def test_regex_alias_no_match(
@defer.inlineCallbacks
def test_regex_multiple_matches(
self,
) -> Generator["defer.Deferred[object]", object, None]:
) -> Generator["defer.Deferred[Any]", object, None]:
self.service.namespaces[ApplicationService.NS_ALIASES].append(
_regex("#irc_.*:matrix.org")
)
Expand All @@ -223,18 +209,14 @@ def test_regex_multiple_matches(
self.store.get_local_users_in_room = simple_async_mock([])
self.assertTrue(
(
yield defer.ensureDeferred(
self.service.is_interested_in_event(
self.event.event_id, self.event, self.store
)
yield self.service.is_interested_in_event(
self.event.event_id, self.event, self.store
)
)
)

@defer.inlineCallbacks
def test_interested_in_self(
self,
) -> Generator["defer.Deferred[object]", object, None]:
def test_interested_in_self(self) -> Generator["defer.Deferred[Any]", object, None]:
# make sure invites get through
self.service.sender = "@appservice:name"
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
Expand All @@ -243,18 +225,14 @@ def test_interested_in_self(
self.event.state_key = self.service.sender
self.assertTrue(
(
yield defer.ensureDeferred(
self.service.is_interested_in_event(
self.event.event_id, self.event, self.store
)
yield self.service.is_interested_in_event(
self.event.event_id, self.event, self.store
)
)
)

@defer.inlineCallbacks
def test_member_list_match(
self,
) -> Generator["defer.Deferred[object]", object, None]:
def test_member_list_match(self) -> Generator["defer.Deferred[Any]", object, None]:
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
# Note that @irc_fo:here is the AS user.
self.store.get_local_users_in_room = simple_async_mock(
Expand All @@ -265,10 +243,8 @@ def test_member_list_match(
self.event.sender = "@xmpp_foobar:matrix.org"
self.assertTrue(
(
yield defer.ensureDeferred(
self.service.is_interested_in_event(
self.event.event_id, self.event, self.store
)
yield self.service.is_interested_in_event(
self.event.event_id, self.event, self.store
)
)
)
11 changes: 5 additions & 6 deletions tests/storage/test_transactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,14 @@ def test_get_set_transactions(self) -> None:
destination retries, as well as testing tht we can set and get
correctly.
"""
d = self.store.get_destination_retry_timings("example.com")
r = self.get_success(d)
r = self.get_success(self.store.get_destination_retry_timings("example.com"))
self.assertIsNone(r)

d = self.store.set_destination_retry_timings("example.com", 1000, 50, 100)
self.get_success(d)
self.get_success(
self.store.set_destination_retry_timings("example.com", 1000, 50, 100)
)

d = self.store.get_destination_retry_timings("example.com")
r = self.get_success(d)
r = self.get_success(self.store.get_destination_retry_timings("example.com"))
Comment on lines -37 to +43
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The get and set versions now properly have different return types, I just inlined them instead of having a d and a d2.


self.assertEqual(
DestinationRetryTimings(
Expand Down