Skip to content

Commit

Permalink
Add missing type hints for tests.unittest. (matrix-org#13397)
Browse files Browse the repository at this point in the history
  • Loading branch information
clokep authored and azmeuk committed Aug 8, 2022
1 parent 1bac869 commit 74dcd62
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 52 deletions.
1 change: 1 addition & 0 deletions changelog.d/13397.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Adding missing type hints to tests.
12 changes: 2 additions & 10 deletions tests/handlers/test_directory.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,17 +481,13 @@ def default_config(self) -> Dict[str, Any]:

return config

def prepare(
self, reactor: MemoryReactor, clock: Clock, hs: HomeServer
) -> HomeServer:
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.allowed_user_id = self.register_user(self.allowed_localpart, "pass")
self.allowed_access_token = self.login(self.allowed_localpart, "pass")

self.denied_user_id = self.register_user("denied", "pass")
self.denied_access_token = self.login("denied", "pass")

return hs

def test_denied_without_publication_permission(self) -> None:
"""
Try to create a room, register an alias for it, and publish it,
Expand Down Expand Up @@ -575,9 +571,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):

servlets = [directory.register_servlets, room.register_servlets]

def prepare(
self, reactor: MemoryReactor, clock: Clock, hs: HomeServer
) -> HomeServer:
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
room_id = self.helper.create_room_as(self.user_id)

channel = self.make_request(
Expand All @@ -588,8 +582,6 @@ def prepare(
self.room_list_handler = hs.get_room_list_handler()
self.directory_handler = hs.get_directory_handler()

return hs

def test_disabling_room_list(self) -> None:
self.room_list_handler.enable_room_list_search = True
self.directory_handler.enable_room_list_search = True
Expand Down
6 changes: 4 additions & 2 deletions tests/rest/client/test_relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1060,6 +1060,7 @@ def assert_thread(bundled_aggregations: JsonDict) -> None:
participated, bundled_aggregations.get("current_user_participated")
)
# The latest thread event has some fields that don't matter.
self.assertIn("latest_event", bundled_aggregations)
self.assert_dict(
{
"content": {
Expand All @@ -1072,7 +1073,7 @@ def assert_thread(bundled_aggregations: JsonDict) -> None:
"sender": self.user2_id,
"type": "m.room.test",
},
bundled_aggregations.get("latest_event"),
bundled_aggregations["latest_event"],
)

return assert_thread
Expand Down Expand Up @@ -1112,6 +1113,7 @@ def assert_thread(bundled_aggregations: JsonDict) -> None:
self.assertEqual(2, bundled_aggregations.get("count"))
self.assertTrue(bundled_aggregations.get("current_user_participated"))
# The latest thread event has some fields that don't matter.
self.assertIn("latest_event", bundled_aggregations)
self.assert_dict(
{
"content": {
Expand All @@ -1124,7 +1126,7 @@ def assert_thread(bundled_aggregations: JsonDict) -> None:
"sender": self.user_id,
"type": "m.room.test",
},
bundled_aggregations.get("latest_event"),
bundled_aggregations["latest_event"],
)
# Check the unsigned field on the latest event.
self.assert_dict(
Expand Down
2 changes: 1 addition & 1 deletion tests/rest/client/test_rooms.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,7 @@ def test_get_state_cancellation(self) -> None:

self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assertCountEqual(
[state_event["type"] for state_event in channel.json_body],
[state_event["type"] for state_event in channel.json_list],
{
"m.room.create",
"m.room.power_levels",
Expand Down
11 changes: 10 additions & 1 deletion tests/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
Callable,
Dict,
Iterable,
List,
MutableMapping,
Optional,
Tuple,
Expand Down Expand Up @@ -121,7 +122,15 @@ def request(self, request: Request) -> None:

@property
def json_body(self) -> JsonDict:
return json.loads(self.text_body)
body = json.loads(self.text_body)
assert isinstance(body, dict)
return body

@property
def json_list(self) -> List[JsonDict]:
body = json.loads(self.text_body)
assert isinstance(body, list)
return body

@property
def text_body(self) -> str:
Expand Down
86 changes: 48 additions & 38 deletions tests/unittest.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
Generic,
Iterable,
List,
NoReturn,
Optional,
Tuple,
Type,
Expand All @@ -39,7 +40,7 @@
import canonicaljson
import signedjson.key
import unpaddedbase64
from typing_extensions import Protocol
from typing_extensions import Concatenate, ParamSpec, Protocol

from twisted.internet.defer import Deferred, ensureDeferred
from twisted.python.failure import Failure
Expand Down Expand Up @@ -67,7 +68,7 @@
from synapse.rest import RegisterServletsFunc
from synapse.server import HomeServer
from synapse.storage.keys import FetchKeyResult
from synapse.types import JsonDict, UserID, create_requester
from synapse.types import JsonDict, Requester, UserID, create_requester
from synapse.util import Clock
from synapse.util.httpresourcetree import create_resource_tree

Expand All @@ -88,6 +89,10 @@
TV = TypeVar("TV")
_ExcType = TypeVar("_ExcType", bound=BaseException, covariant=True)

P = ParamSpec("P")
R = TypeVar("R")
S = TypeVar("S")


class _TypedFailure(Generic[_ExcType], Protocol):
"""Extension to twisted.Failure, where the 'value' has a certain type."""
Expand All @@ -97,7 +102,7 @@ def value(self) -> _ExcType:
...


def around(target):
def around(target: TV) -> Callable[[Callable[Concatenate[S, P], R]], None]:
"""A CLOS-style 'around' modifier, which wraps the original method of the
given instance with another piece of code.
Expand All @@ -106,11 +111,11 @@ def method_name(orig, *args, **kwargs):
return orig(*args, **kwargs)
"""

def _around(code):
def _around(code: Callable[Concatenate[S, P], R]) -> None:
name = code.__name__
orig = getattr(target, name)

def new(*args, **kwargs):
def new(*args: P.args, **kwargs: P.kwargs) -> R:
return code(orig, *args, **kwargs)

setattr(target, name, new)
Expand All @@ -131,7 +136,7 @@ def __init__(self, methodName: str):
level = getattr(method, "loglevel", getattr(self, "loglevel", None))

@around(self)
def setUp(orig):
def setUp(orig: Callable[[], R]) -> R:
# if we're not starting in the sentinel logcontext, then to be honest
# all future bets are off.
if current_context():
Expand All @@ -144,7 +149,7 @@ def setUp(orig):
if level is not None and old_level != level:

@around(self)
def tearDown(orig):
def tearDown(orig: Callable[[], R]) -> R:
ret = orig()
logging.getLogger().setLevel(old_level)
return ret
Expand All @@ -158,7 +163,7 @@ def tearDown(orig):
return orig()

@around(self)
def tearDown(orig):
def tearDown(orig: Callable[[], R]) -> R:
ret = orig()
# force a GC to workaround problems with deferreds leaking logcontexts when
# they are GCed (see the logcontext docs)
Expand All @@ -167,7 +172,7 @@ def tearDown(orig):

return ret

def assertObjectHasAttributes(self, attrs, obj):
def assertObjectHasAttributes(self, attrs: Dict[str, object], obj: object) -> None:
"""Asserts that the given object has each of the attributes given, and
that the value of each matches according to assertEqual."""
for key in attrs.keys():
Expand All @@ -178,44 +183,44 @@ def assertObjectHasAttributes(self, attrs, obj):
except AssertionError as e:
raise (type(e))(f"Assert error for '.{key}':") from e

def assert_dict(self, required, actual):
def assert_dict(self, required: dict, actual: dict) -> None:
"""Does a partial assert of a dict.
Args:
required (dict): The keys and value which MUST be in 'actual'.
actual (dict): The test result. Extra keys will not be checked.
required: The keys and value which MUST be in 'actual'.
actual: The test result. Extra keys will not be checked.
"""
for key in required:
self.assertEqual(
required[key], actual[key], msg="%s mismatch. %s" % (key, actual)
)


def DEBUG(target):
def DEBUG(target: TV) -> TV:
"""A decorator to set the .loglevel attribute to logging.DEBUG.
Can apply to either a TestCase or an individual test method."""
target.loglevel = logging.DEBUG
target.loglevel = logging.DEBUG # type: ignore[attr-defined]
return target


def INFO(target):
def INFO(target: TV) -> TV:
"""A decorator to set the .loglevel attribute to logging.INFO.
Can apply to either a TestCase or an individual test method."""
target.loglevel = logging.INFO
target.loglevel = logging.INFO # type: ignore[attr-defined]
return target


def logcontext_clean(target):
def logcontext_clean(target: TV) -> TV:
"""A decorator which marks the TestCase or method as 'logcontext_clean'
... ie, any logcontext errors should cause a test failure
"""

def logcontext_error(msg):
def logcontext_error(msg: str) -> NoReturn:
raise AssertionError("logcontext error: %s" % (msg))

patcher = patch("synapse.logging.context.logcontext_error", new=logcontext_error)
return patcher(target)
return patcher(target) # type: ignore[call-overload]


class HomeserverTestCase(TestCase):
Expand Down Expand Up @@ -255,7 +260,7 @@ def __init__(self, methodName: str):
method = getattr(self, methodName)
self._extra_config = getattr(method, "_extra_config", None)

def setUp(self):
def setUp(self) -> None:
"""
Set up the TestCase by calling the homeserver constructor, optionally
hijacking the authentication system to return a fixed user, and then
Expand Down Expand Up @@ -306,15 +311,21 @@ def setUp(self):
)
)

async def get_user_by_access_token(token=None, allow_guest=False):
async def get_user_by_access_token(
token: Optional[str] = None, allow_guest: bool = False
) -> JsonDict:
assert self.helper.auth_user_id is not None
return {
"user": UserID.from_string(self.helper.auth_user_id),
"token_id": token_id,
"is_guest": False,
}

async def get_user_by_req(request, allow_guest=False):
async def get_user_by_req(
request: SynapseRequest,
allow_guest: bool = False,
allow_expired: bool = False,
) -> Requester:
assert self.helper.auth_user_id is not None
return create_requester(
UserID.from_string(self.helper.auth_user_id),
Expand All @@ -339,11 +350,11 @@ async def get_user_by_req(request, allow_guest=False):
if hasattr(self, "prepare"):
self.prepare(self.reactor, self.clock, self.hs)

def tearDown(self):
def tearDown(self) -> None:
# Reset to not use frozen dicts.
events.USE_FROZEN_DICTS = False

def wait_on_thread(self, deferred, timeout=10):
def wait_on_thread(self, deferred: Deferred, timeout: int = 10) -> None:
"""
Wait until a Deferred is done, where it's waiting on a real thread.
"""
Expand Down Expand Up @@ -374,7 +385,7 @@ def make_homeserver(self, reactor, clock):
clock (synapse.util.Clock): The Clock, associated with the reactor.
Returns:
A homeserver (synapse.server.HomeServer) suitable for testing.
A homeserver suitable for testing.
Function to be overridden in subclasses.
"""
Expand Down Expand Up @@ -408,7 +419,7 @@ def create_resource_dict(self) -> Dict[str, Resource]:
"/_synapse/admin": servlet_resource,
}

def default_config(self):
def default_config(self) -> JsonDict:
"""
Get a default HomeServer config dict.
"""
Expand All @@ -421,7 +432,9 @@ def default_config(self):

return config

def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer):
def prepare(
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
) -> None:
"""
Prepare for the test. This involves things like mocking out parts of
the homeserver, or building test data common across the whole test
Expand Down Expand Up @@ -519,7 +532,7 @@ def setup_test_homeserver(self, *args: Any, **kwargs: Any) -> HomeServer:
config_obj.parse_config_dict(config, "", "")
kwargs["config"] = config_obj

async def run_bg_updates():
async def run_bg_updates() -> None:
with LoggingContext("run_bg_updates"):
self.get_success(stor.db_pool.updates.run_background_updates(False))

Expand All @@ -538,11 +551,7 @@ def pump(self, by: float = 0.0) -> None:
"""
self.reactor.pump([by] * 100)

def get_success(
self,
d: Awaitable[TV],
by: float = 0.0,
) -> TV:
def get_success(self, d: Awaitable[TV], by: float = 0.0) -> TV:
deferred: Deferred[TV] = ensureDeferred(d) # type: ignore[arg-type]
self.pump(by=by)
return self.successResultOf(deferred)
Expand Down Expand Up @@ -755,7 +764,7 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
OTHER_SERVER_NAME = "other.example.com"
OTHER_SERVER_SIGNATURE_KEY = signedjson.key.generate_signing_key("test")

def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
super().prepare(reactor, clock, hs)

# poke the other server's signing key into the key store, so that we don't
Expand Down Expand Up @@ -879,7 +888,7 @@ def _auth_header_for_request(
)


def override_config(extra_config):
def override_config(extra_config: JsonDict) -> Callable[[TV], TV]:
"""A decorator which can be applied to test functions to give additional HS config
For use
Expand All @@ -892,12 +901,13 @@ def test_foo(self):
...
Args:
extra_config(dict): Additional config settings to be merged into the default
extra_config: Additional config settings to be merged into the default
config dict before instantiating the test homeserver.
"""

def decorator(func):
func._extra_config = extra_config
def decorator(func: TV) -> TV:
# This attribute is being defined.
func._extra_config = extra_config # type: ignore[attr-defined]
return func

return decorator
Expand Down

0 comments on commit 74dcd62

Please sign in to comment.