This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add missing type hints in tests #14879

Merged
merged 5 commits on Jan 26, 2023
Changes from 3 commits
6 changes: 3 additions & 3 deletions mypy.ini
@@ -45,10 +45,7 @@ exclude = (?x)
|tests/http/federation/test_matrix_federation_agent.py
|tests/http/federation/test_srv_resolver.py
|tests/http/test_proxyagent.py
|tests/logging/__init__.py
|tests/logging/test_terse_json.py
|tests/module_api/test_api.py
|tests/rest/client/test_transactions.py
|tests/rest/media/v1/test_media_storage.py
|tests/server.py
|tests/server_notices/test_resource_limits_server_notices.py
@@ -95,6 +92,9 @@ disallow_untyped_defs = True
[mypy-tests.handlers.*]
disallow_untyped_defs = True

[mypy-tests.logging.*]
disallow_untyped_defs = True

[mypy-tests.metrics.*]
disallow_untyped_defs = True

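Note on the configuration change above: removing these test modules from mypy's exclude list and adding a [mypy-tests.logging.*] section with disallow_untyped_defs = True means every function in those modules must now be fully annotated. A minimal sketch of what that enforces, using a hypothetical test rather than code from this PR:

```python
import logging
import unittest


class ExampleLoggingTestCase(unittest.TestCase):
    # Before: `def test_logs_something(self):` -- under disallow_untyped_defs,
    # mypy rejects this because the function lacks a return type annotation.
    def test_logs_something(self) -> None:
        logger = logging.getLogger("example")
        logger.info("Hello there, wally!")
        self.assertEqual(logger.name, "example")
```
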
3 changes: 2 additions & 1 deletion synapse/rest/client/transactions.py
@@ -19,6 +19,7 @@

from typing_extensions import ParamSpec

from twisted.internet.defer import Deferred
from twisted.python.failure import Failure
from twisted.web.server import Request

@@ -90,7 +91,7 @@ def fetch_or_execute(
fn: Callable[P, Awaitable[Tuple[int, JsonDict]]],
*args: P.args,
**kwargs: P.kwargs,
) -> Awaitable[Tuple[int, JsonDict]]:
) -> "Deferred[Tuple[int, JsonDict]]":
"""Fetches the response for this transaction, or executes the given function
to produce a response for this transaction.

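Note on the return type change above: fetch_or_execute hands back a Twisted Deferred, so "Deferred[Tuple[int, JsonDict]]" is strictly more precise than Awaitable[Tuple[int, JsonDict]], and a Deferred still satisfies callers that only need an awaitable. A minimal sketch of the idea, with illustrative names rather than Synapse's API:

```python
from typing import Awaitable, Tuple

from twisted.internet import defer
from twisted.internet.defer import Deferred


def fetch_cached() -> "Deferred[Tuple[int, dict]]":
    # The quoted annotation stays a plain string at runtime, so it does not
    # rely on Deferred being subscriptable in the installed Twisted version.
    response: Tuple[int, dict] = (200, {"result": "ok"})
    return defer.succeed(response)


async def caller() -> Tuple[int, dict]:
    # A Deferred is still acceptable wherever an Awaitable is expected.
    result: Awaitable[Tuple[int, dict]] = fetch_cached()
    return await result
```
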
6 changes: 4 additions & 2 deletions tests/logging/__init__.py
@@ -13,9 +13,11 @@
# limitations under the License.
import logging

from tests.unittest import TestCase

class LoggerCleanupMixin:
def get_logger(self, handler):

class LoggerCleanupMixin(TestCase):
def get_logger(self, handler: logging.Handler) -> logging.Logger:
"""
Attach a handler to a logger and add clean-ups to revert this.
"""
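The other half of this change is making LoggerCleanupMixin inherit from TestCase, which among other things lets mypy see self.addCleanup when checking the mixin. The sketch below shows the general shape of such a mixin; it is illustrative and not the exact Synapse implementation:

```python
import logging
from unittest import TestCase


class ExampleLoggerCleanupMixin(TestCase):
    def get_logger(self, handler: logging.Handler) -> logging.Logger:
        """Attach a handler to a per-test logger and register clean-ups to revert this."""
        logger = logging.getLogger(self.id())
        logger.addHandler(handler)
        logger.setLevel(logging.DEBUG)
        # Undo both changes once the test finishes.
        self.addCleanup(logger.removeHandler, handler)
        self.addCleanup(logger.setLevel, logging.NOTSET)
        return logger
```
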
4 changes: 2 additions & 2 deletions tests/logging/test_opentracing.py
@@ -153,7 +153,7 @@ def test_overlapping_spans(self) -> None:

scopes = []

async def task(i: int):
async def task(i: int) -> None:
scope = start_active_span(
f"task{i}",
tracer=self._tracer,
@@ -165,7 +165,7 @@ async def task(i: int):
self.assertEqual(self._tracer.active_span, scope.span)
scope.close()

async def root():
async def root() -> None:
with start_active_span("root span", tracer=self._tracer) as root_scope:
self.assertEqual(self._tracer.active_span, root_scope.span)
scopes.append(root_scope)
25 changes: 17 additions & 8 deletions tests/logging/test_remote_handler.py
@@ -11,7 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.test.proto_helpers import AccumulatingProtocol
from typing import Tuple

from twisted.internet.protocol import Protocol
from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock

from synapse.logging import RemoteHandler

@@ -20,7 +23,9 @@
from tests.unittest import TestCase


def connect_logging_client(reactor, client_id):
def connect_logging_client(
reactor: MemoryReactorClock, client_id: int
) -> Tuple[Protocol, AccumulatingProtocol]:
# This is essentially tests.server.connect_client, but disabling autoflush on
# the client transport. This is necessary to avoid an infinite loop due to
# sending of data via the logging transport causing additional logs to be
@@ -35,10 +40,10 @@ def connect_logging_client(reactor, client_id):


class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase):
def setUp(self):
def setUp(self) -> None:
self.reactor, _ = get_clock()

def test_log_output(self):
def test_log_output(self) -> None:
"""
The remote handler delivers logs over TCP.
"""
@@ -51,6 +56,7 @@ def test_log_output(self):
client, server = connect_logging_client(self.reactor, 0)

# Trigger data being sent
assert isinstance(client.transport, FakeTransport)
client.transport.flush()

# One log message, with a single trailing newline
@@ -61,7 +67,7 @@
# Ensure the data passed through properly.
self.assertEqual(logs[0], "Hello there, wally!")

def test_log_backpressure_debug(self):
def test_log_backpressure_debug(self) -> None:
"""
When backpressure is hit, DEBUG logs will be shed.
"""
@@ -83,14 +89,15 @@ def test_log_backpressure_debug(self):

# Allow the reconnection
client, server = connect_logging_client(self.reactor, 0)
assert isinstance(client.transport, FakeTransport)
client.transport.flush()

# Only the 7 infos made it through, the debugs were elided
logs = server.data.splitlines()
self.assertEqual(len(logs), 7)
self.assertNotIn(b"debug", server.data)

def test_log_backpressure_info(self):
def test_log_backpressure_info(self) -> None:
"""
When backpressure is hit, DEBUG and INFO logs will be shed.
"""
@@ -116,6 +123,7 @@ def test_log_backpressure_info(self):

# Allow the reconnection
client, server = connect_logging_client(self.reactor, 0)
assert isinstance(client.transport, FakeTransport)
client.transport.flush()

# The 10 warnings made it through, the debugs and infos were elided
Expand All @@ -124,7 +132,7 @@ def test_log_backpressure_info(self):
self.assertNotIn(b"debug", server.data)
self.assertNotIn(b"info", server.data)

def test_log_backpressure_cut_middle(self):
def test_log_backpressure_cut_middle(self) -> None:
"""
When backpressure is hit, and no more DEBUGs and INFOs can be culled,
it will cut the middle messages out.
@@ -140,6 +148,7 @@ def test_log_backpressure_cut_middle(self):

# Allow the reconnection
client, server = connect_logging_client(self.reactor, 0)
assert isinstance(client.transport, FakeTransport)
client.transport.flush()

# The first five and last five warnings made it through, the debugs and
Expand All @@ -151,7 +160,7 @@ def test_log_backpressure_cut_middle(self):
logs,
)

def test_cancel_connection(self):
def test_cancel_connection(self) -> None:
"""
Gracefully handle the connection being cancelled.
"""
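The repeated assert isinstance(client.transport, FakeTransport) lines exist to narrow the transport attribute's declared type so that mypy accepts the flush() call; at runtime they are cheap sanity checks. A minimal, generic sketch of the pattern (illustrative names, not Synapse's classes):

```python
from typing import Optional


class BaseTransport:
    def write(self, data: bytes) -> None:
        ...


class FlushableTransport(BaseTransport):
    """A transport variant that buffers writes until flush() is called."""

    def flush(self) -> None:
        ...


class Client:
    transport: Optional[BaseTransport] = None


def flush_client(client: Client) -> None:
    # Without the assert, mypy rejects the call: the declared type is
    # Optional[BaseTransport], which has no flush() and may be None.
    assert isinstance(client.transport, FlushableTransport)
    client.transport.flush()
```
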
30 changes: 18 additions & 12 deletions tests/logging/test_terse_json.py
@@ -14,32 +14,36 @@
import json
import logging
from io import BytesIO, StringIO
from typing import cast
from unittest.mock import Mock, patch

from twisted.web.http import HTTPChannel
from twisted.web.server import Request

from synapse.http.site import SynapseRequest
from synapse.logging._terse_json import JsonFormatter, TerseJsonFormatter
from synapse.logging.context import LoggingContext, LoggingContextFilter
from synapse.types import JsonDict

from tests.logging import LoggerCleanupMixin
from tests.server import FakeChannel
from tests.server import FakeChannel, get_clock
from tests.unittest import TestCase


class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
def setUp(self):
def setUp(self) -> None:
self.output = StringIO()
self.reactor, _ = get_clock()

def get_log_line(self):
def get_log_line(self) -> JsonDict:
# One log message, with a single trailing newline.
data = self.output.getvalue()
logs = data.splitlines()
self.assertEqual(len(logs), 1)
self.assertEqual(data.count("\n"), 1)
return json.loads(logs[0])

def test_terse_json_output(self):
def test_terse_json_output(self) -> None:
"""
The Terse JSON formatter converts log messages to JSON.
"""
@@ -61,7 +65,7 @@ def test_terse_json_output(self):
self.assertCountEqual(log.keys(), expected_log_keys)
self.assertEqual(log["log"], "Hello there, wally!")

def test_extra_data(self):
def test_extra_data(self) -> None:
"""
Additional information can be included in the structured logging.
"""
@@ -93,7 +97,7 @@ def test_extra_data(self):
self.assertEqual(log["int"], 3)
self.assertIs(log["bool"], True)

def test_json_output(self):
def test_json_output(self) -> None:
"""
The Terse JSON formatter converts log messages to JSON.
"""
@@ -114,7 +118,7 @@ def test_json_output(self):
self.assertCountEqual(log.keys(), expected_log_keys)
self.assertEqual(log["log"], "Hello there, wally!")

def test_with_context(self):
def test_with_context(self) -> None:
"""
The logging context should be added to the JSON response.
"""
@@ -139,7 +143,7 @@ def test_with_context(self):
self.assertEqual(log["log"], "Hello there, wally!")
self.assertEqual(log["request"], "name")

def test_with_request_context(self):
def test_with_request_context(self) -> None:
"""
Information from the logging context request should be added to the JSON response.
"""
@@ -154,11 +158,13 @@ def test_with_request_context(self):
site.server_version_string = "Server v1"
site.reactor = Mock()
site.experimental_cors_msc3886 = False
request = SynapseRequest(FakeChannel(site, None), site)
request = SynapseRequest(
cast(HTTPChannel, FakeChannel(site, self.reactor)), site
)
# Call requestReceived to finish instantiating the object.
request.content = BytesIO()
# Partially skip some of the internal processing of SynapseRequest.
request._started_processing = Mock()
# Partially skip some internal processing of SynapseRequest.
request._started_processing = Mock() # type: ignore[assignment]
request.request_metrics = Mock(spec=["name"])
with patch.object(Request, "render"):
request.requestReceived(b"POST", b"/_matrix/client/versions", b"1.1")
@@ -200,7 +206,7 @@ def test_with_request_context(self):
self.assertEqual(log["protocol"], "1.1")
self.assertEqual(log["user_agent"], "")

def test_with_exception(self):
def test_with_exception(self) -> None:
"""
The logging exception type & value should be added to the JSON response.
"""
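Two further typing techniques appear in this file: cast(HTTPChannel, FakeChannel(...)) tells mypy to treat the test double as the channel type SynapseRequest expects, and the type: ignore[assignment] silences the error produced by replacing a method with a Mock. A generic sketch of both (illustrative names, not Synapse's classes; the exact ignore code may differ between mypy versions):

```python
from typing import cast
from unittest.mock import Mock


class RealChannel:
    def send(self, data: bytes) -> None:
        ...


class FakeTestChannel:
    """A stand-in that only implements what the test needs."""

    def send(self, data: bytes) -> None:
        ...


class Handler:
    def __init__(self, channel: RealChannel) -> None:
        self.channel = channel

    def _started_processing(self) -> None:
        ...


# cast() is a no-op at runtime; it only changes what mypy believes the type is.
handler = Handler(cast(RealChannel, FakeTestChannel()))

# Overwriting a method with a Mock is an intentional type mismatch, so the
# error is suppressed with a narrowly scoped ignore.
handler._started_processing = Mock()  # type: ignore[assignment]
```

cast() rather than an isinstance assert is the natural choice when the test double is not a real subclass of the expected type, since runtime narrowing would fail.
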
42 changes: 28 additions & 14 deletions tests/rest/client/test_transactions.py
@@ -13,18 +13,22 @@
# limitations under the License.

from http import HTTPStatus
from typing import Any, Generator, Tuple, cast
from unittest.mock import Mock, call

from twisted.internet import defer, reactor
from twisted.internet import defer, reactor as _reactor

from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context
from synapse.rest.client.transactions import CLEANUP_PERIOD_MS, HttpTransactionCache
from synapse.types import ISynapseReactor, JsonDict
from synapse.util import Clock

from tests import unittest
from tests.test_utils import make_awaitable
from tests.utils import MockClock

reactor = cast(ISynapseReactor, _reactor)


class HttpTransactionCacheTestCase(unittest.TestCase):
def setUp(self) -> None:
Expand All @@ -34,11 +38,13 @@ def setUp(self) -> None:
self.hs.get_auth = Mock()
self.cache = HttpTransactionCache(self.hs)

self.mock_http_response = (HTTPStatus.OK, "GOOD JOB!")
self.mock_http_response = (HTTPStatus.OK, {"result": "GOOD JOB!"})
Comment on lines -37 to +41

Contributor: I guess this is just making the mock response match what we expect a RestServlet method to return?

Member (author): fetch_or_execute requires an (int, JsonDict) I believe. So yeah it seemed easiest to just give a more accurate value.

self.mock_key = "foo"

@defer.inlineCallbacks
def test_executes_given_function(self):
def test_executes_given_function(
self,
) -> Generator["defer.Deferred[Any]", object, None]:
Contributor: I think this signature means that

  • res: Any?
  • no constraints on the type of the thing being yielded

That's ever so slightly sad, but I don't think there's a good way to get mypy to check this short of using proper async defs. (And I think that is a little awkward to do in tests?)

Member (author): It unfortunately needs to match what defer.inlineCallbacks wants, I think. And the type on that isn't super accurate IIRC?

cb = Mock(return_value=make_awaitable(self.mock_http_response))
res = yield self.cache.fetch_or_execute(
self.mock_key, cb, "some_arg", keyword="arg"
Expand All @@ -47,7 +53,9 @@ def test_executes_given_function(self):
self.assertEqual(res, self.mock_http_response)

@defer.inlineCallbacks
def test_deduplicates_based_on_key(self):
def test_deduplicates_based_on_key(
self,
) -> Generator["defer.Deferred[Any]", object, None]:
cb = Mock(return_value=make_awaitable(self.mock_http_response))
for i in range(3): # invoke multiple times
res = yield self.cache.fetch_or_execute(
Expand All @@ -58,18 +66,20 @@ def test_deduplicates_based_on_key(self):
cb.assert_called_once_with("some_arg", keyword="arg", changing_args=0)

@defer.inlineCallbacks
def test_logcontexts_with_async_result(self):
def test_logcontexts_with_async_result(
self,
) -> Generator["defer.Deferred[Any]", object, None]:
@defer.inlineCallbacks
def cb():
def cb() -> Generator["defer.Deferred[object]", object, Tuple[int, JsonDict]]:
yield Clock(reactor).sleep(0)
return "yay"
return 1, {}

@defer.inlineCallbacks
def test():
def test() -> Generator["defer.Deferred[Any]", object, None]:
with LoggingContext("c") as c1:
res = yield self.cache.fetch_or_execute(self.mock_key, cb)
self.assertIs(current_context(), c1)
self.assertEqual(res, "yay")
self.assertEqual(res, (1, {}))

# run the test twice in parallel
d = defer.gatherResults([test(), test()])
Expand All @@ -78,13 +88,15 @@ def test():
self.assertIs(current_context(), SENTINEL_CONTEXT)

@defer.inlineCallbacks
def test_does_not_cache_exceptions(self):
def test_does_not_cache_exceptions(
self,
) -> Generator["defer.Deferred[Any]", object, None]:
"""Checks that, if the callback throws an exception, it is called again
for the next request.
"""
called = [False]

def cb():
def cb() -> "defer.Deferred[Tuple[int, JsonDict]]":
if called[0]:
# return a valid result the second time
return defer.succeed(self.mock_http_response)
Expand All @@ -104,13 +116,15 @@ def cb():
self.assertIs(current_context(), test_context)

@defer.inlineCallbacks
def test_does_not_cache_failures(self):
def test_does_not_cache_failures(
self,
) -> Generator["defer.Deferred[Any]", object, None]:
"""Checks that, if the callback returns a failure, it is called again
for the next request.
"""
called = [False]

def cb():
def cb() -> "defer.Deferred[Tuple[int, JsonDict]]":
if called[0]:
# return a valid result the second time
return defer.succeed(self.mock_http_response)
Expand All @@ -130,7 +144,7 @@ def cb():
self.assertIs(current_context(), test_context)

@defer.inlineCallbacks
def test_cleans_up(self):
def test_cleans_up(self) -> Generator["defer.Deferred[Any]", object, None]:
cb = Mock(return_value=make_awaitable(self.mock_http_response))
yield self.cache.fetch_or_execute(self.mock_key, cb, "an arg")
# should NOT have cleaned up yet
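The review thread above explains the Generator annotations in this file: @defer.inlineCallbacks consumes a generator that yields Deferreds, is sent their results back (essentially untyped, hence object), and whose return value becomes the result of the Deferred returned by the decorated call. A small self-contained sketch of the pattern, with illustrative names rather than code from the Synapse test suite:

```python
from typing import Any, Generator, Tuple

from twisted.internet import defer


@defer.inlineCallbacks
def fetch_status() -> Generator["defer.Deferred[Any]", object, Tuple[int, dict]]:
    # Yield type: the Deferreds handed to the inlineCallbacks trampoline.
    # Send type: whatever the trampoline sends back in (untracked, hence object).
    # Return type: what the Deferred returned by the decorated call resolves to.
    yield defer.succeed(None)  # stand-in for some asynchronous step
    return 200, {"result": "GOOD JOB!"}


d = fetch_status()  # the decorator wraps the generator in a Deferred
d.addCallback(print)  # prints (200, {'result': 'GOOD JOB!'}) once it fires
```

Because the send type is object, values bound from a yield (like res in the tests above) carry essentially no type information, which is the limitation discussed in the review comments.
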