Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Prefer make_awaitable over defer.succeed in tests #12505

Merged
merged 9 commits into from
Apr 27, 2022
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/12505.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Use `make_awaitable` instead of `defer.succeed` for return values of mocks in tests.
24 changes: 17 additions & 7 deletions synapse/logging/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -802,17 +802,27 @@ def run_in_background( # type: ignore[misc]
# by synchronous exceptions, so let's turn them into Failures.
return defer.fail()

# First we handle coroutines by wrapping them in a `Deferred`.
if isinstance(res, typing.Coroutine):
res = defer.ensureDeferred(res)

# At this point we should have a Deferred, if not then f was a synchronous
# function, wrap it in a Deferred for consistency.
# At this point, `res` may be a plain value, `Deferred`, or some other kind of
# non-coroutine awaitable.
if not isinstance(res, defer.Deferred):
# `res` is not a `Deferred` and not a `Coroutine`.
# There are no other types of `Awaitable`s we expect to encounter in Synapse.
assert not isinstance(res, Awaitable)

return defer.succeed(res)
# Wrap plain values in a `Deferred`.
if not isinstance(res, Awaitable):
return defer.succeed(res)

# `res` is some kind of awaitable that is not a coroutine or `Deferred`.
# We assume that it is a completed awaitable, such as a `DoneAwaitable` or
# `Future` from `make_awaitable`, and await it manually.
iterator = res.__await__() # `__await__` returns an iterator...
try:
next(iterator)
raise ValueError(f"Function {f} returned an unresolved awaitable: {res}")
except StopIteration as e:
# ...which raises a `StopIteration` once the awaitable is complete.
return defer.succeed(e.value)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm very tempted to replace all this with:

    # `res` may be a coroutine, `Deferred`, some other kind of awaitable, or a plain
    # value.
    if isinstance(res, typing.Coroutine):
        # Wrap the coroutine in a `Deferred`.
        res = defer.ensureDeferred(res)
    elif isinstance(res, defer.Deferred):
        pass
    elif isinstance(res, Awaitable):
        # `res` is probably some kind of completed awaitable, such as a `DoneAwaitable`
        # or `Future` from `make_awaitable`.
        def awaiter(awaitable: Awaitable[R]):
            return await awaitable
        res = defer.ensureDeferred(awaiter(res))
    else:
        # `res` is a plain value. Wrap it in a `Deferred`.
        return defer.succeed(res)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ie. get rid of the nested ifs and spawn a coroutine to deal with weird awaitables that we don't quite know how to handle.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

well, that sounds good!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

       def awaiter(awaitable: Awaitable[R]):
           return await awaitable
       res = defer.ensureDeferred(awaiter(res))

if you're going to define a local function anyway, I'd go with:

        async def awaiter():
            return await res
        res = defer.ensureDeferred(awaiter())

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I gave that a go and mypy wasn't convinced that res is an awaitable inside the function:

synapse/logging/context.py:816: error: Incompatible types in "await" (actual type "Union[R, Awaitable[R]]", expected type "Awaitable[Any]")  [misc]


if res.called and not res.paused:
# The function should have maintained the logcontext, so we can
Expand Down
2 changes: 1 addition & 1 deletion tests/federation/test_federation_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def test_get_room_state(self):
)

# mock up the response, and have the agent return it
self._mock_agent.request.return_value = defer.succeed(
self._mock_agent.request.side_effect = lambda *args, **kwargs: defer.succeed(
clokep marked this conversation as resolved.
Show resolved Hide resolved
_mock_response(
{
"pdus": [
Expand Down
2 changes: 1 addition & 1 deletion tests/federation/test_federation_sender.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def test_dont_send_device_updates_for_remote_users(self):
# Send the server a device list EDU for the other user, this will cause
# it to try and resync the device lists.
self.hs.get_federation_transport_client().query_user_devices.return_value = (
defer.succeed(
make_awaitable(
{
"stream_id": "1",
"user_id": "@user2:host2",
Expand Down
7 changes: 3 additions & 4 deletions tests/handlers/test_e2e_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from parameterized import parameterized
from signedjson import key as key, sign as sign

from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor

from synapse.api.constants import RoomEncryptionAlgorithms
Expand Down Expand Up @@ -704,7 +703,7 @@ def test_query_devices_remote_no_sync(self) -> None:
remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"

self.hs.get_federation_client().query_client_keys = mock.Mock(
return_value=defer.succeed(
return_value=make_awaitable(
{
"device_keys": {remote_user_id: {}},
"master_keys": {
Expand Down Expand Up @@ -777,14 +776,14 @@ def test_query_devices_remote_sync(self) -> None:
# Pretend we're sharing a room with the user we're querying. If not,
# `_query_devices_for_destination` will return early.
self.store.get_rooms_for_user = mock.Mock(
return_value=defer.succeed({"some_room_id"})
return_value=make_awaitable({"some_room_id"})
)

remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY"
remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"

self.hs.get_federation_client().query_user_devices = mock.Mock(
return_value=defer.succeed(
return_value=make_awaitable(
{
"user_id": remote_user_id,
"stream_id": 1,
Expand Down
34 changes: 16 additions & 18 deletions tests/handlers/test_password_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@
from typing import Any, Type, Union
from unittest.mock import Mock

from twisted.internet import defer

import synapse
from synapse.api.constants import LoginType
from synapse.api.errors import Codes
Expand Down Expand Up @@ -190,7 +188,7 @@ def password_only_auth_provider_login_test_body(self):
self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS)

# check_password must return an awaitable
mock_password_provider.check_password.return_value = defer.succeed(True)
mock_password_provider.check_password.return_value = make_awaitable(True)
channel = self._send_password_login("u", "p")
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual("@u:test", channel.json_body["user_id"])
Expand Down Expand Up @@ -226,13 +224,13 @@ def password_only_auth_provider_ui_auth_test_body(self):
self.get_success(module_api.register_user("u"))

# log in twice, to get two devices
mock_password_provider.check_password.return_value = defer.succeed(True)
mock_password_provider.check_password.return_value = make_awaitable(True)
tok1 = self.login("u", "p")
self.login("u", "p", device_id="dev2")
mock_password_provider.reset_mock()

# have the auth provider deny the request to start with
mock_password_provider.check_password.return_value = defer.succeed(False)
mock_password_provider.check_password.return_value = make_awaitable(False)

# make the initial request which returns a 401
session = self._start_delete_device_session(tok1, "dev2")
Expand All @@ -246,7 +244,7 @@ def password_only_auth_provider_ui_auth_test_body(self):
mock_password_provider.reset_mock()

# Finally, check the request goes through when we allow it
mock_password_provider.check_password.return_value = defer.succeed(True)
mock_password_provider.check_password.return_value = make_awaitable(True)
channel = self._authed_delete_device(tok1, "dev2", session, "u", "p")
self.assertEqual(channel.code, 200)
mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
Expand All @@ -260,7 +258,7 @@ def local_user_fallback_login_test_body(self):
self.register_user("localuser", "localpass")

# check_password must return an awaitable
mock_password_provider.check_password.return_value = defer.succeed(False)
mock_password_provider.check_password.return_value = make_awaitable(False)
channel = self._send_password_login("u", "p")
self.assertEqual(channel.code, 403, channel.result)

Expand All @@ -277,7 +275,7 @@ def local_user_fallback_ui_auth_test_body(self):
self.register_user("localuser", "localpass")

# have the auth provider deny the request
mock_password_provider.check_password.return_value = defer.succeed(False)
mock_password_provider.check_password.return_value = make_awaitable(False)

# log in twice, to get two devices
tok1 = self.login("localuser", "localpass")
Expand Down Expand Up @@ -320,7 +318,7 @@ def no_local_user_fallback_login_test_body(self):
self.register_user("localuser", "localpass")

# check_password must return an awaitable
mock_password_provider.check_password.return_value = defer.succeed(False)
mock_password_provider.check_password.return_value = make_awaitable(False)
channel = self._send_password_login("localuser", "localpass")
self.assertEqual(channel.code, 403)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
Expand All @@ -342,7 +340,7 @@ def no_local_user_fallback_ui_auth_test_body(self):
self.register_user("localuser", "localpass")

# allow login via the auth provider
mock_password_provider.check_password.return_value = defer.succeed(True)
mock_password_provider.check_password.return_value = make_awaitable(True)

# log in twice, to get two devices
tok1 = self.login("localuser", "p")
Expand All @@ -359,7 +357,7 @@ def no_local_user_fallback_ui_auth_test_body(self):
mock_password_provider.check_password.assert_not_called()

# now try deleting with the local password
mock_password_provider.check_password.return_value = defer.succeed(False)
mock_password_provider.check_password.return_value = make_awaitable(False)
channel = self._authed_delete_device(
tok1, "dev2", session, "localuser", "localpass"
)
Expand Down Expand Up @@ -413,7 +411,7 @@ def custom_auth_provider_login_test_body(self):
self.assertEqual(channel.code, 400, channel.result)
mock_password_provider.check_auth.assert_not_called()

mock_password_provider.check_auth.return_value = defer.succeed(
mock_password_provider.check_auth.return_value = make_awaitable(
("@user:bz", None)
)
channel = self._send_login("test.login_type", "u", test_field="y")
Expand All @@ -427,7 +425,7 @@ def custom_auth_provider_login_test_body(self):
# try a weird username. Again, it's unclear what we *expect* to happen
# in these cases, but at least we can guard against the API changing
# unexpectedly
mock_password_provider.check_auth.return_value = defer.succeed(
mock_password_provider.check_auth.return_value = make_awaitable(
("@ MALFORMED! :bz", None)
)
channel = self._send_login("test.login_type", " USER🙂NAME ", test_field=" abc ")
Expand Down Expand Up @@ -477,7 +475,7 @@ def custom_auth_provider_ui_auth_test_body(self):
mock_password_provider.reset_mock()

# right params, but authing as the wrong user
mock_password_provider.check_auth.return_value = defer.succeed(
mock_password_provider.check_auth.return_value = make_awaitable(
("@user:bz", None)
)
body["auth"]["test_field"] = "foo"
Expand All @@ -490,7 +488,7 @@ def custom_auth_provider_ui_auth_test_body(self):
mock_password_provider.reset_mock()

# and finally, succeed
mock_password_provider.check_auth.return_value = defer.succeed(
mock_password_provider.check_auth.return_value = make_awaitable(
("@localuser:test", None)
)
channel = self._delete_device(tok1, "dev2", body)
Expand All @@ -508,9 +506,9 @@ def test_custom_auth_provider_callback(self):
self.custom_auth_provider_callback_test_body()

def custom_auth_provider_callback_test_body(self):
callback = Mock(return_value=defer.succeed(None))
callback = Mock(return_value=make_awaitable(None))

mock_password_provider.check_auth.return_value = defer.succeed(
mock_password_provider.check_auth.return_value = make_awaitable(
("@user:bz", callback)
)
channel = self._send_login("test.login_type", "u", test_field="y")
Expand Down Expand Up @@ -646,7 +644,7 @@ def password_custom_auth_password_disabled_ui_auth_test_body(self):
login is disabled"""
# register the user and log in twice via the test login type to get two devices,
self.register_user("localuser", "localpass")
mock_password_provider.check_auth.return_value = defer.succeed(
mock_password_provider.check_auth.return_value = make_awaitable(
("@localuser:test", None)
)
channel = self._send_login("test.login_type", "localuser", test_field="")
Expand Down
6 changes: 3 additions & 3 deletions tests/handlers/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,11 @@ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
# we mock out the keyring so as to skip the authentication check on the
# federation API call.
mock_keyring = Mock(spec=["verify_json_for_server"])
mock_keyring.verify_json_for_server.return_value = defer.succeed(True)
mock_keyring.verify_json_for_server.return_value = make_awaitable(True)

# we mock out the federation client too
mock_federation_client = Mock(spec=["put_json"])
mock_federation_client.put_json.return_value = defer.succeed((200, "OK"))
mock_federation_client.put_json.return_value = make_awaitable((200, "OK"))

# the tests assume that we are starting at unix time 1000
reactor.pump((1000,))
Expand Down Expand Up @@ -98,7 +98,7 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:

self.datastore = hs.get_datastores().main
self.datastore.get_destination_retry_timings = Mock(
return_value=defer.succeed(None)
return_value=make_awaitable(None)
)

self.datastore.get_device_updates_by_remote = Mock(
Expand Down
6 changes: 3 additions & 3 deletions tests/handlers/test_user_directory.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from unittest.mock import Mock, patch
from urllib.parse import quote

from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor

import synapse.rest.admin
Expand All @@ -30,6 +29,7 @@

from tests import unittest
from tests.storage.test_user_directory import GetUserDirectoryTables
from tests.test_utils import make_awaitable
from tests.test_utils.event_injection import inject_member_event
from tests.unittest import override_config

Expand Down Expand Up @@ -439,7 +439,7 @@ def test_handle_user_deactivated_support_user(self) -> None:
)
)

mock_remove_from_user_dir = Mock(return_value=defer.succeed(None))
mock_remove_from_user_dir = Mock(return_value=make_awaitable(None))
with patch.object(
self.store, "remove_from_user_dir", mock_remove_from_user_dir
):
Expand All @@ -454,7 +454,7 @@ def test_handle_user_deactivated_regular_user(self) -> None:
self.store.register_user(user_id=r_user_id, password_hash=None)
)

mock_remove_from_user_dir = Mock(return_value=defer.succeed(None))
mock_remove_from_user_dir = Mock(return_value=make_awaitable(None))
with patch.object(
self.store, "remove_from_user_dir", mock_remove_from_user_dir
):
Expand Down
4 changes: 2 additions & 2 deletions tests/rest/client/test_presence.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from http import HTTPStatus
from unittest.mock import Mock

from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor

from synapse.handlers.presence import PresenceHandler
Expand All @@ -24,6 +23,7 @@
from synapse.util import Clock

from tests import unittest
from tests.test_utils import make_awaitable


class PresenceTestCase(unittest.HomeserverTestCase):
Expand All @@ -37,7 +37,7 @@ class PresenceTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:

presence_handler = Mock(spec=PresenceHandler)
presence_handler.set_state.return_value = defer.succeed(None)
presence_handler.set_state.return_value = make_awaitable(None)

hs = self.setup_test_homeserver(
"red",
Expand Down
7 changes: 2 additions & 5 deletions tests/rest/client/test_rooms.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from unittest.mock import Mock, call
from urllib import parse as urlparse

from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor

import synapse.rest.admin
Expand Down Expand Up @@ -1426,9 +1425,7 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:

def test_simple(self) -> None:
"Simple test for searching rooms over federation"
self.federation_client.get_public_rooms.side_effect = lambda *a, **k: defer.succeed( # type: ignore[attr-defined]
{}
)
self.federation_client.get_public_rooms.return_value = make_awaitable({}) # type: ignore[attr-defined]

search_filter = {"generic_search_term": "foobar"}

Expand Down Expand Up @@ -1456,7 +1453,7 @@ def test_fallback(self) -> None:
# with a 404, when using search filters.
self.federation_client.get_public_rooms.side_effect = ( # type: ignore[attr-defined]
HttpResponseException(404, "Not Found", b""),
defer.succeed({}),
make_awaitable({}),
)

search_filter = {"generic_search_term": "foobar"}
Expand Down
7 changes: 4 additions & 3 deletions tests/rest/client/test_transactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from synapse.util import Clock

from tests import unittest
from tests.test_utils import make_awaitable
from tests.utils import MockClock


Expand All @@ -38,7 +39,7 @@ def setUp(self) -> None:

@defer.inlineCallbacks
def test_executes_given_function(self):
cb = Mock(return_value=defer.succeed(self.mock_http_response))
cb = Mock(return_value=make_awaitable(self.mock_http_response))
res = yield self.cache.fetch_or_execute(
self.mock_key, cb, "some_arg", keyword="arg"
)
Expand All @@ -47,7 +48,7 @@ def test_executes_given_function(self):

@defer.inlineCallbacks
def test_deduplicates_based_on_key(self):
cb = Mock(return_value=defer.succeed(self.mock_http_response))
cb = Mock(return_value=make_awaitable(self.mock_http_response))
for i in range(3): # invoke multiple times
res = yield self.cache.fetch_or_execute(
self.mock_key, cb, "some_arg", keyword="arg", changing_args=i
Expand Down Expand Up @@ -130,7 +131,7 @@ def cb():

@defer.inlineCallbacks
def test_cleans_up(self):
cb = Mock(return_value=defer.succeed(self.mock_http_response))
cb = Mock(return_value=make_awaitable(self.mock_http_response))
yield self.cache.fetch_or_execute(self.mock_key, cb, "an arg")
# should NOT have cleaned up yet
self.clock.advance_time_msec(CLEANUP_PERIOD_MS / 2)
Expand Down
Loading