From 1674b7fa701f1d9db1023bb8a95677a2f61cb76e Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 26 Jul 2022 13:49:19 -0400 Subject: [PATCH 01/11] Add missing type hints for override_config. --- tests/unittest.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/unittest.py b/tests/unittest.py index 66ce92f4a6e4..c41d92a56a12 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -879,7 +879,7 @@ def _auth_header_for_request( ) -def override_config(extra_config): +def override_config(extra_config: JsonDict) -> Callable[[TV], TV]: """A decorator which can be applied to test functions to give additional HS config For use @@ -892,12 +892,13 @@ def test_foo(self): ... Args: - extra_config(dict): Additional config settings to be merged into the default + extra_config: Additional config settings to be merged into the default config dict before instantiating the test homeserver. """ - def decorator(func): - func._extra_config = extra_config + def decorator(func: TV) -> TV: + # This attribute is being defined. + func._extra_config = extra_config # type: ignore[attr-defined] return func return decorator From 6c0344a4adc9422a8cd8a137ebae4e339557929b Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 26 Jul 2022 14:19:29 -0400 Subject: [PATCH 02/11] Add type hints to prepare. --- tests/handlers/test_directory.py | 12 ++---------- tests/unittest.py | 6 ++++-- 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index 53d49ca89617..3b72c4c9d019 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -481,17 +481,13 @@ def default_config(self) -> Dict[str, Any]: return config - def prepare( - self, reactor: MemoryReactor, clock: Clock, hs: HomeServer - ) -> HomeServer: + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.allowed_user_id = self.register_user(self.allowed_localpart, "pass") self.allowed_access_token = self.login(self.allowed_localpart, "pass") self.denied_user_id = self.register_user("denied", "pass") self.denied_access_token = self.login("denied", "pass") - return hs - def test_denied_without_publication_permission(self) -> None: """ Try to create a room, register an alias for it, and publish it, @@ -575,9 +571,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase): servlets = [directory.register_servlets, room.register_servlets] - def prepare( - self, reactor: MemoryReactor, clock: Clock, hs: HomeServer - ) -> HomeServer: + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: room_id = self.helper.create_room_as(self.user_id) channel = self.make_request( @@ -588,8 +582,6 @@ def prepare( self.room_list_handler = hs.get_room_list_handler() self.directory_handler = hs.get_directory_handler() - return hs - def test_disabling_room_list(self) -> None: self.room_list_handler.enable_room_list_search = True self.directory_handler.enable_room_list_search = True diff --git a/tests/unittest.py b/tests/unittest.py index c41d92a56a12..ec022b093628 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -421,7 +421,9 @@ def default_config(self): return config - def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer): + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: """ Prepare for the test. This involves things like mocking out parts of the homeserver, or building test data common across the whole test @@ -755,7 +757,7 @@ class FederatingHomeserverTestCase(HomeserverTestCase): OTHER_SERVER_NAME = "other.example.com" OTHER_SERVER_SIGNATURE_KEY = signedjson.key.generate_signing_key("test") - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: super().prepare(reactor, clock, hs) # poke the other server's signing key into the key store, so that we don't From e36b12d4b5fd1af8060b38e8d84e96f488ae3891 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 26 Jul 2022 14:16:16 -0400 Subject: [PATCH 03/11] Add type hints to assert_dict. --- tests/rest/client/test_relations.py | 6 ++++-- tests/unittest.py | 6 +++--- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index ad03eee17bc8..d589f073143b 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -1060,6 +1060,7 @@ def assert_thread(bundled_aggregations: JsonDict) -> None: participated, bundled_aggregations.get("current_user_participated") ) # The latest thread event has some fields that don't matter. + self.assertIn("latest_event", bundled_aggregations) self.assert_dict( { "content": { @@ -1072,7 +1073,7 @@ def assert_thread(bundled_aggregations: JsonDict) -> None: "sender": self.user2_id, "type": "m.room.test", }, - bundled_aggregations.get("latest_event"), + bundled_aggregations["latest_event"], ) return assert_thread @@ -1112,6 +1113,7 @@ def assert_thread(bundled_aggregations: JsonDict) -> None: self.assertEqual(2, bundled_aggregations.get("count")) self.assertTrue(bundled_aggregations.get("current_user_participated")) # The latest thread event has some fields that don't matter. + self.assertIn("latest_event", bundled_aggregations) self.assert_dict( { "content": { @@ -1124,7 +1126,7 @@ def assert_thread(bundled_aggregations: JsonDict) -> None: "sender": self.user_id, "type": "m.room.test", }, - bundled_aggregations.get("latest_event"), + bundled_aggregations["latest_event"], ) # Check the unsigned field on the latest event. self.assert_dict( diff --git a/tests/unittest.py b/tests/unittest.py index ec022b093628..0bdb946bb9ae 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -178,12 +178,12 @@ def assertObjectHasAttributes(self, attrs, obj): except AssertionError as e: raise (type(e))(f"Assert error for '.{key}':") from e - def assert_dict(self, required, actual): + def assert_dict(self, required: dict, actual: dict) -> None: """Does a partial assert of a dict. Args: - required (dict): The keys and value which MUST be in 'actual'. - actual (dict): The test result. Extra keys will not be checked. + required: The keys and value which MUST be in 'actual'. + actual: The test result. Extra keys will not be checked. """ for key in required: self.assertEqual( From e4012f08ed4ba2179328ff8c253d8f4a87bbb6a4 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 26 Jul 2022 14:30:21 -0400 Subject: [PATCH 04/11] Add type hints to around. --- tests/unittest.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/tests/unittest.py b/tests/unittest.py index 0bdb946bb9ae..582ef3af2b32 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -39,7 +39,7 @@ import canonicaljson import signedjson.key import unpaddedbase64 -from typing_extensions import Protocol +from typing_extensions import ParamSpec, Protocol from twisted.internet.defer import Deferred, ensureDeferred from twisted.python.failure import Failure @@ -88,6 +88,9 @@ TV = TypeVar("TV") _ExcType = TypeVar("_ExcType", bound=BaseException, covariant=True) +P = ParamSpec("P") +R = TypeVar("R") + class _TypedFailure(Generic[_ExcType], Protocol): """Extension to twisted.Failure, where the 'value' has a certain type.""" @@ -97,7 +100,7 @@ def value(self) -> _ExcType: ... -def around(target): +def around(target: TV) -> Callable[[Callable[P, R]], None]: """A CLOS-style 'around' modifier, which wraps the original method of the given instance with another piece of code. @@ -106,11 +109,11 @@ def method_name(orig, *args, **kwargs): return orig(*args, **kwargs) """ - def _around(code): + def _around(code: Callable[P, R]) -> None: name = code.__name__ orig = getattr(target, name) - def new(*args, **kwargs): + def new(*args: P.args, **kwargs: P.kwargs) -> R: return code(orig, *args, **kwargs) setattr(target, name, new) @@ -131,7 +134,7 @@ def __init__(self, methodName: str): level = getattr(method, "loglevel", getattr(self, "loglevel", None)) @around(self) - def setUp(orig): + def setUp(orig: Callable[[], R]) -> R: # if we're not starting in the sentinel logcontext, then to be honest # all future bets are off. if current_context(): @@ -144,7 +147,7 @@ def setUp(orig): if level is not None and old_level != level: @around(self) - def tearDown(orig): + def tearDown(orig: Callable[[], R]) -> R: ret = orig() logging.getLogger().setLevel(old_level) return ret @@ -158,7 +161,7 @@ def tearDown(orig): return orig() @around(self) - def tearDown(orig): + def tearDown(orig: Callable[[], R]) -> R: ret = orig() # force a GC to workaround problems with deferreds leaking logcontexts when # they are GCed (see the logcontext docs) From 4853dd51eb9c3f4bfe7e2ab5167a6fe81a9a46f4 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 26 Jul 2022 14:32:49 -0400 Subject: [PATCH 05/11] Add type hints to logging methods. --- tests/rest/client/test_rooms.py | 4 +++- tests/unittest.py | 15 ++++++++------- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index c45cb3209025..f99756499e7c 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -495,8 +495,10 @@ def test_get_state_cancellation(self) -> None: ) self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) + # json_body is defined as JsonDict, but it can be any valid JSON. + json_body: List[JsonDict] = channel.json_body # type: ignore[assignment] self.assertCountEqual( - [state_event["type"] for state_event in channel.json_body], + [state_event["type"] for state_event in json_body], { "m.room.create", "m.room.power_levels", diff --git a/tests/unittest.py b/tests/unittest.py index 582ef3af2b32..613ad369a79f 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -29,6 +29,7 @@ Iterable, List, Optional, + NoReturn, Tuple, Type, TypeVar, @@ -194,31 +195,31 @@ def assert_dict(self, required: dict, actual: dict) -> None: ) -def DEBUG(target): +def DEBUG(target: TV) -> TV: """A decorator to set the .loglevel attribute to logging.DEBUG. Can apply to either a TestCase or an individual test method.""" - target.loglevel = logging.DEBUG + target.loglevel = logging.DEBUG # type: ignore[attr-defined] return target -def INFO(target): +def INFO(target: TV) -> TV: """A decorator to set the .loglevel attribute to logging.INFO. Can apply to either a TestCase or an individual test method.""" - target.loglevel = logging.INFO + target.loglevel = logging.INFO # type: ignore[attr-defined] return target -def logcontext_clean(target): +def logcontext_clean(target: TV) -> TV: """A decorator which marks the TestCase or method as 'logcontext_clean' ... ie, any logcontext errors should cause a test failure """ - def logcontext_error(msg): + def logcontext_error(msg: str) -> NoReturn: raise AssertionError("logcontext error: %s" % (msg)) patcher = patch("synapse.logging.context.logcontext_error", new=logcontext_error) - return patcher(target) + return patcher(target) # type: ignore[call-overload] class HomeserverTestCase(TestCase): From 76270c3ceada228ca9220c946ed2184500c17b2c Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 26 Jul 2022 14:41:07 -0400 Subject: [PATCH 06/11] Add missing type hints. --- tests/unittest.py | 32 +++++++++++++++++--------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/tests/unittest.py b/tests/unittest.py index 613ad369a79f..389275af81e5 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -68,7 +68,7 @@ from synapse.rest import RegisterServletsFunc from synapse.server import HomeServer from synapse.storage.keys import FetchKeyResult -from synapse.types import JsonDict, UserID, create_requester +from synapse.types import JsonDict, Requester, UserID, create_requester from synapse.util import Clock from synapse.util.httpresourcetree import create_resource_tree @@ -171,7 +171,7 @@ def tearDown(orig: Callable[[], R]) -> R: return ret - def assertObjectHasAttributes(self, attrs, obj): + def assertObjectHasAttributes(self, attrs: Dict[str, object], obj: object) -> None: """Asserts that the given object has each of the attributes given, and that the value of each matches according to assertEqual.""" for key in attrs.keys(): @@ -259,7 +259,7 @@ def __init__(self, methodName: str): method = getattr(self, methodName) self._extra_config = getattr(method, "_extra_config", None) - def setUp(self): + def setUp(self) -> None: """ Set up the TestCase by calling the homeserver constructor, optionally hijacking the authentication system to return a fixed user, and then @@ -310,7 +310,9 @@ def setUp(self): ) ) - async def get_user_by_access_token(token=None, allow_guest=False): + async def get_user_by_access_token( + token: Optional[str] = None, allow_guest: bool = False + ) -> JsonDict: assert self.helper.auth_user_id is not None return { "user": UserID.from_string(self.helper.auth_user_id), @@ -318,7 +320,11 @@ async def get_user_by_access_token(token=None, allow_guest=False): "is_guest": False, } - async def get_user_by_req(request, allow_guest=False): + async def get_user_by_req( + request: SynapseRequest, + allow_guest: bool = False, + allow_expired: bool = False, + ) -> Requester: assert self.helper.auth_user_id is not None return create_requester( UserID.from_string(self.helper.auth_user_id), @@ -343,11 +349,11 @@ async def get_user_by_req(request, allow_guest=False): if hasattr(self, "prepare"): self.prepare(self.reactor, self.clock, self.hs) - def tearDown(self): + def tearDown(self) -> None: # Reset to not use frozen dicts. events.USE_FROZEN_DICTS = False - def wait_on_thread(self, deferred, timeout=10): + def wait_on_thread(self, deferred: Deferred, timeout: int = 10) -> None: """ Wait until a Deferred is done, where it's waiting on a real thread. """ @@ -378,7 +384,7 @@ def make_homeserver(self, reactor, clock): clock (synapse.util.Clock): The Clock, associated with the reactor. Returns: - A homeserver (synapse.server.HomeServer) suitable for testing. + A homeserver suitable for testing. Function to be overridden in subclasses. """ @@ -412,7 +418,7 @@ def create_resource_dict(self) -> Dict[str, Resource]: "/_synapse/admin": servlet_resource, } - def default_config(self): + def default_config(self) -> JsonDict: """ Get a default HomeServer config dict. """ @@ -525,7 +531,7 @@ def setup_test_homeserver(self, *args: Any, **kwargs: Any) -> HomeServer: config_obj.parse_config_dict(config, "", "") kwargs["config"] = config_obj - async def run_bg_updates(): + async def run_bg_updates() -> None: with LoggingContext("run_bg_updates"): self.get_success(stor.db_pool.updates.run_background_updates(False)) @@ -544,11 +550,7 @@ def pump(self, by: float = 0.0) -> None: """ self.reactor.pump([by] * 100) - def get_success( - self, - d: Awaitable[TV], - by: float = 0.0, - ) -> TV: + def get_success(self, d: Awaitable[TV], by: float = 0.0) -> TV: deferred: Deferred[TV] = ensureDeferred(d) # type: ignore[arg-type] self.pump(by=by) return self.successResultOf(deferred) From 5a21ef7388ca7b406cc65c5e0f1dc99aef53dc87 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 26 Jul 2022 15:10:18 -0400 Subject: [PATCH 07/11] Newsfragment --- changelog.d/13397.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/13397.misc diff --git a/changelog.d/13397.misc b/changelog.d/13397.misc new file mode 100644 index 000000000000..8dc610d9e290 --- /dev/null +++ b/changelog.d/13397.misc @@ -0,0 +1 @@ +Adding missing type hints to tests. From 700dfcf22f20e8254bb42346b01c7a03a3a5f1ad Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 26 Jul 2022 15:20:42 -0400 Subject: [PATCH 08/11] Lint --- tests/unittest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unittest.py b/tests/unittest.py index 389275af81e5..548b9fa4c707 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -28,8 +28,8 @@ Generic, Iterable, List, - Optional, NoReturn, + Optional, Tuple, Type, TypeVar, From aca8cf6808cb22c7186f6ebc5011790033bcbec3 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 27 Jul 2022 07:29:20 -0400 Subject: [PATCH 09/11] Add a separate json_list method on FakeChannel. --- tests/rest/client/test_rooms.py | 4 +--- tests/server.py | 11 ++++++++++- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index f99756499e7c..2272d55d84ec 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -495,10 +495,8 @@ def test_get_state_cancellation(self) -> None: ) self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) - # json_body is defined as JsonDict, but it can be any valid JSON. - json_body: List[JsonDict] = channel.json_body # type: ignore[assignment] self.assertCountEqual( - [state_event["type"] for state_event in json_body], + [state_event["type"] for state_event in channel.json_list], { "m.room.create", "m.room.power_levels", diff --git a/tests/server.py b/tests/server.py index df3f1564c9ea..9689e6a0cdc5 100644 --- a/tests/server.py +++ b/tests/server.py @@ -25,6 +25,7 @@ Callable, Dict, Iterable, + List, MutableMapping, Optional, Tuple, @@ -121,7 +122,15 @@ def request(self, request: Request) -> None: @property def json_body(self) -> JsonDict: - return json.loads(self.text_body) + body = json.loads(self.text_body) + assert isinstance(body, dict) + return body + + @property + def json_list(self) -> List[JsonDict]: + body = json.loads(self.text_body) + assert isinstance(body, list) + return body @property def text_body(self) -> str: From 603d70acd1fc1c5e479dff59e41d7fa73beea041 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 27 Jul 2022 12:25:02 -0400 Subject: [PATCH 10/11] Additional corrections from review. --- tests/unittest.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/unittest.py b/tests/unittest.py index 548b9fa4c707..711c34b33023 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -40,7 +40,7 @@ import canonicaljson import signedjson.key import unpaddedbase64 -from typing_extensions import ParamSpec, Protocol +from typing_extensions import Concatenate, ParamSpec, Protocol from twisted.internet.defer import Deferred, ensureDeferred from twisted.python.failure import Failure @@ -101,7 +101,7 @@ def value(self) -> _ExcType: ... -def around(target: TV) -> Callable[[Callable[P, R]], None]: +def around(target: TV) -> Callable[[Callable[Concatenate[TV, P], R]], None]: """A CLOS-style 'around' modifier, which wraps the original method of the given instance with another piece of code. @@ -110,7 +110,7 @@ def method_name(orig, *args, **kwargs): return orig(*args, **kwargs) """ - def _around(code: Callable[P, R]) -> None: + def _around(code: Callable[Concatenate[TV, P], R]) -> None: name = code.__name__ orig = getattr(target, name) From 74d6acb1da2b1ca6f34adb896fe020a7025c2fcb Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 27 Jul 2022 12:52:35 -0400 Subject: [PATCH 11/11] Use another typevar. --- tests/unittest.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/unittest.py b/tests/unittest.py index 711c34b33023..bec4a3d02396 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -91,6 +91,7 @@ P = ParamSpec("P") R = TypeVar("R") +S = TypeVar("S") class _TypedFailure(Generic[_ExcType], Protocol): @@ -101,7 +102,7 @@ def value(self) -> _ExcType: ... -def around(target: TV) -> Callable[[Callable[Concatenate[TV, P], R]], None]: +def around(target: TV) -> Callable[[Callable[Concatenate[S, P], R]], None]: """A CLOS-style 'around' modifier, which wraps the original method of the given instance with another piece of code. @@ -110,7 +111,7 @@ def method_name(orig, *args, **kwargs): return orig(*args, **kwargs) """ - def _around(code: Callable[Concatenate[TV, P], R]) -> None: + def _around(code: Callable[Concatenate[S, P], R]) -> None: name = code.__name__ orig = getattr(target, name)