From 1c58fa119521b54c20d4fe0daff6f75d78820b3e Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 16 Feb 2023 10:00:11 -0500 Subject: [PATCH 01/20] Do not ignore tests/server.py. --- mypy.ini | 2 -- 1 file changed, 2 deletions(-) diff --git a/mypy.ini b/mypy.ini index ff6e04b12fe0..94562d0bce10 100644 --- a/mypy.ini +++ b/mypy.ini @@ -31,8 +31,6 @@ exclude = (?x) |synapse/storage/databases/__init__.py |synapse/storage/databases/main/cache.py |synapse/storage/schema/ - - |tests/server.py )$ [mypy-synapse.federation.transport.client] From 3db9e3ae05398d5b0dfc415fa5427981e157b1d9 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 16 Feb 2023 08:57:58 -0500 Subject: [PATCH 02/20] Clean-up type hints for FakeTransport. --- .../test_matrix_federation_agent.py | 5 ++- tests/http/test_proxyagent.py | 5 ++- tests/server.py | 43 ++++++++++--------- 3 files changed, 30 insertions(+), 23 deletions(-) diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index d27422515c8f..8d0a2d56aae9 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -30,7 +30,7 @@ IOpenSSLClientConnectionCreator, IProtocolFactory, ) -from twisted.internet.protocol import Factory +from twisted.internet.protocol import Factory, Protocol from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol from twisted.web._newclient import ResponseNeverReceived from twisted.web.client import Agent @@ -466,7 +466,10 @@ def _do_get_via_proxy( else: assert isinstance(proxy_server_transport, FakeTransport) client_protocol = proxy_server_transport.other + assert isinstance(client_protocol, Protocol) c2s_transport = client_protocol.transport + assert c2s_transport is not None + assert isinstance(c2s_transport, FakeTransport) c2s_transport.other = server_ssl_protocol self.reactor.advance(0) diff --git a/tests/http/test_proxyagent.py b/tests/http/test_proxyagent.py index 22fdc7f5f23f..ba6e09bbf0b3 100644 --- a/tests/http/test_proxyagent.py +++ b/tests/http/test_proxyagent.py @@ -28,7 +28,7 @@ _WrappingProtocol, ) from twisted.internet.interfaces import IProtocol, IProtocolFactory -from twisted.internet.protocol import Factory +from twisted.internet.protocol import Factory, Protocol from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol from twisted.web.http import HTTPChannel @@ -644,7 +644,10 @@ def _do_https_request_via_proxy( else: assert isinstance(proxy_server_transport, FakeTransport) client_protocol = proxy_server_transport.other + assert isinstance(client_protocol, Protocol) c2s_transport = client_protocol.transport + assert c2s_transport is not None + assert isinstance(c2s_transport, FakeTransport) c2s_transport.other = server_ssl_protocol self.reactor.advance(0) diff --git a/tests/server.py b/tests/server.py index 237bcad8ba6f..8e3da150c823 100644 --- a/tests/server.py +++ b/tests/server.py @@ -28,6 +28,7 @@ List, MutableMapping, Optional, + Sequence, Tuple, Type, Union, @@ -573,7 +574,7 @@ def get_clock() -> Tuple[ThreadedMemoryReactorClock, Clock]: @implementer(ITransport) -@attr.s(cmp=False) +@attr.s(cmp=False, auto_attribs=True) class FakeTransport: """ A twisted.internet.interfaces.ITransport implementation which sends all its data @@ -588,35 +589,35 @@ class FakeTransport: If you want bidirectional communication, you'll need two instances. 
""" - other = attr.ib() + other: IProtocol """The Protocol object which will receive any data written to this transport. :type: twisted.internet.interfaces.IProtocol """ - _reactor = attr.ib() + _reactor: IReactorTime """Test reactor :type: twisted.internet.interfaces.IReactorTime """ - _protocol = attr.ib(default=None) + _protocol: Optional[IProtocol] = None """The Protocol which is producing data for this transport. Optional, but if set will get called back for connectionLost() notifications etc. """ - _peer_address: Optional[IAddress] = attr.ib(default=None) + _peer_address: Optional[IAddress] = None """The value to be returned by getPeer""" - _host_address: Optional[IAddress] = attr.ib(default=None) + _host_address: Optional[IAddress] = None """The value to be returned by getHost""" disconnecting = False disconnected = False connected = True - buffer = attr.ib(default=b"") - producer = attr.ib(default=None) - autoflush = attr.ib(default=True) + buffer: bytes = attr.Factory(bytes) + producer: Optional[IPushProducer] = None + autoflush: bool = True def getPeer(self) -> Optional[IAddress]: return self._peer_address @@ -624,12 +625,12 @@ def getPeer(self) -> Optional[IAddress]: def getHost(self) -> Optional[IAddress]: return self._host_address - def loseConnection(self, reason=None): + def loseConnection(self, reason: Optional[Failure] = None) -> None: if not self.disconnecting: logger.info("FakeTransport: loseConnection(%s)", reason) self.disconnecting = True if self._protocol: - self._protocol.connectionLost(reason) + self._protocol.connectionLost(reason) # type: ignore[arg-type] # if we still have data to write, delay until that is done if self.buffer: @@ -640,38 +641,38 @@ def loseConnection(self, reason=None): self.connected = False self.disconnected = True - def abortConnection(self): + def abortConnection(self) -> None: logger.info("FakeTransport: abortConnection()") if not self.disconnecting: self.disconnecting = True if self._protocol: - self._protocol.connectionLost(None) + self._protocol.connectionLost(None) # type: ignore[arg-type] self.disconnected = True - def pauseProducing(self): + def pauseProducing(self) -> None: if not self.producer: return self.producer.pauseProducing() - def resumeProducing(self): + def resumeProducing(self) -> None: if not self.producer: return self.producer.resumeProducing() - def unregisterProducer(self): + def unregisterProducer(self) -> None: if not self.producer: return self.producer = None - def registerProducer(self, producer, streaming): + def registerProducer(self, producer: IPushProducer, streaming: bool) -> None: self.producer = producer self.producerStreaming = streaming - def _produce(): + def _produce() -> None: if not self.producer: # we've been unregistered return @@ -683,7 +684,7 @@ def _produce(): if not streaming: self._reactor.callLater(0.0, _produce) - def write(self, byt): + def write(self, byt: bytes) -> None: if self.disconnecting: raise Exception("Writing to disconnecting FakeTransport") @@ -695,11 +696,11 @@ def write(self, byt): if self.autoflush: self._reactor.callLater(0.0, self.flush) - def writeSequence(self, seq): + def writeSequence(self, seq: Iterable[bytes]) -> None: for x in seq: self.write(x) - def flush(self, maxbytes=None): + def flush(self, maxbytes: Optional[int] = None) -> None: if not self.buffer: # nothing to do. 
Don't write empty buffers: it upsets the # TLSMemoryBIOProtocol From 721c05290494aa0679b5eb04b00aa204e9d499a6 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 16 Feb 2023 08:59:22 -0500 Subject: [PATCH 03/20] Add missing type hints to FakeChannel and FakeSite. --- tests/server.py | 50 +++++++++++++++++++++++++++++++++++-------------- 1 file changed, 36 insertions(+), 14 deletions(-) diff --git a/tests/server.py b/tests/server.py index 8e3da150c823..2bd0dfdeb615 100644 --- a/tests/server.py +++ b/tests/server.py @@ -32,12 +32,13 @@ Tuple, Type, Union, + cast, ) from unittest.mock import Mock import attr from typing_extensions import Deque -from zope.interface import implementer +from zope.interface import implementer, providedBy from twisted.internet import address, threads, udp from twisted.internet._resolver import SimpleResolverComplexifier @@ -47,6 +48,7 @@ IAddress, IConsumer, IHostnameResolver, + IProducer, IProtocol, IPullProducer, IPushProducer, @@ -99,12 +101,14 @@ class TimedOutException(Exception): """ -@implementer(IConsumer) +@implementer(ITransport, IPushProducer, IConsumer) @attr.s(auto_attribs=True) class FakeChannel: """ A fake Twisted Web Channel (the part that interfaces with the wire). + + See twisted.web.http.HTTPChannel. """ site: Union[Site, "FakeSite"] @@ -143,7 +147,7 @@ def text_body(self) -> str: Raises an exception if the request has not yet completed. """ - if not self.is_finished: + if not self.is_finished(): raise Exception("Request not yet completed") return self.result["body"].decode("utf8") @@ -166,27 +170,35 @@ def headers(self) -> Headers: h.addRawHeader(*i) return h - def writeHeaders(self, version, code, reason, headers): + def writeHeaders( + self, version: bytes, code: bytes, reason: bytes, headers: Headers + ) -> None: self.result["version"] = version self.result["code"] = code self.result["reason"] = reason self.result["headers"] = headers - def write(self, content: bytes) -> None: - assert isinstance(content, bytes), "Should be bytes! " + repr(content) + def write(self, data: bytes) -> None: + assert isinstance(data, bytes), "Should be bytes! " + repr(data) if "body" not in self.result: self.result["body"] = b"" - self.result["body"] += content + self.result["body"] += data + + def writeSequence(self, data: Iterable[bytes]) -> None: + for x in data: + self.write(x) + + def loseConnection(self) -> None: + self.unregisterProducer() + self.transport.loseConnection() # Type ignore: mypy doesn't like the fact that producer isn't an IProducer. - def registerProducer( # type: ignore[override] - self, - producer: Union[IPullProducer, IPushProducer], - streaming: bool, - ) -> None: - self._producer = producer + def registerProducer(self, producer: IProducer, streaming: bool) -> None: + # Ensure that the producer implements one or more of IPushProducer and IPullProducer. 
+ assert not set(providedBy(producer)).isdisjoint({IPushProducer, IPullProducer}) + self._producer = cast(Union[IPushProducer, IPullProducer], producer) self.producerStreaming = streaming def _produce() -> None: @@ -203,6 +215,16 @@ def unregisterProducer(self) -> None: self._producer = None + def stopProducing(self) -> None: + if self._producer is not None: + self._producer.stopProducing() + + def pauseProducing(self) -> None: + raise NotImplementedError() + + def resumeProducing(self) -> None: + raise NotImplementedError() + def requestDone(self, _self: Request) -> None: self.result["done"] = True if isinstance(_self, SynapseRequest): @@ -282,7 +304,7 @@ def __init__( self.reactor = reactor self.experimental_cors_msc3886 = experimental_cors_msc3886 - def getResourceFor(self, request): + def getResourceFor(self, request: Request) -> IResource: return self._resource From a171830bef0958355de1bc7100298fa7ccf1eea7 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 16 Feb 2023 09:20:59 -0500 Subject: [PATCH 04/20] Fix-up ThreadedMemoryReactorClock. --- tests/server.py | 59 +++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 47 insertions(+), 12 deletions(-) diff --git a/tests/server.py b/tests/server.py index 2bd0dfdeb615..e09e86ccfd98 100644 --- a/tests/server.py +++ b/tests/server.py @@ -22,6 +22,7 @@ from collections import deque from io import SEEK_END, BytesIO from typing import ( + Any, Callable, Dict, Iterable, @@ -31,6 +32,7 @@ Sequence, Tuple, Type, + TypeVar, Union, cast, ) @@ -46,6 +48,7 @@ from twisted.internet.error import DNSLookupError from twisted.internet.interfaces import ( IAddress, + IConnector, IConsumer, IHostnameResolver, IProducer, @@ -57,6 +60,8 @@ IResolverSimple, ITransport, ) +from twisted.internet.protocol import ClientFactory, DatagramProtocol +from twisted.python import threadpool from twisted.python.failure import Failure from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock from twisted.web.http_headers import Headers @@ -91,6 +96,8 @@ logger = logging.getLogger(__name__) +R = TypeVar("R") + # the type of thing that can be passed into `make_request` in the headers list CustomHeaderType = Tuple[Union[str, bytes], Union[str, bytes]] @@ -432,19 +439,21 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): A MemoryReactorClock that supports callFromThread. 
""" - def __init__(self): + def __init__(self) -> None: self.threadpool = ThreadPool(self) self._tcp_callbacks: Dict[Tuple[str, int], Callable] = {} - self._udp = [] + self._udp: List[udp.Port] = [] self.lookups: Dict[str, str] = {} - self._thread_callbacks: Deque[Callable[[], None]] = deque() + self._thread_callbacks: Deque[Callable[..., R]] = deque() lookups = self.lookups @implementer(IResolverSimple) class FakeResolver: - def getHostByName(self, name, timeout=None): + def getHostByName( + self, name: str, timeout: Optional[Sequence[int]] = None + ) -> "Deferred[str]": if name not in lookups: return fail(DNSLookupError("OH NO: unknown %s" % (name,))) return succeed(lookups[name]) @@ -455,25 +464,44 @@ def getHostByName(self, name, timeout=None): def installNameResolver(self, resolver: IHostnameResolver) -> IHostnameResolver: raise NotImplementedError() - def listenUDP(self, port, protocol, interface="", maxPacketSize=8196): + def listenUDP( + self, + port: int, + protocol: DatagramProtocol, + interface: str = "", + maxPacketSize: int = 8196, + ) -> udp.Port: p = udp.Port(port, protocol, interface, maxPacketSize, self) p.startListening() self._udp.append(p) return p - def callFromThread(self, callback, *args, **kwargs): + def callFromThread( + self, callable: Callable[..., Any], *args: object, **kwargs: object + ) -> None: """ Make the callback fire in the next reactor iteration. """ - cb = lambda: callback(*args, **kwargs) + cb = lambda: callable(*args, **kwargs) # it's not safe to call callLater() here, so we append the callback to a # separate queue. self._thread_callbacks.append(cb) - def getThreadPool(self): - return self.threadpool + def callInThread( + self, callable: Callable[..., Any], *args: object, **kwargs: object + ) -> None: + raise NotImplementedError() + + def suggestThreadPoolSize(self, size: int) -> None: + raise NotImplementedError() + + def getThreadPool(self) -> "threadpool.ThreadPool": + # Cast to match super-class. + return cast(threadpool.ThreadPool, self.threadpool) - def add_tcp_client_callback(self, host: str, port: int, callback: Callable): + def add_tcp_client_callback( + self, host: str, port: int, callback: Callable[[], None] + ) -> None: """Add a callback that will be invoked when we receive a connection attempt to the given IP/port using `connectTCP`. @@ -482,7 +510,14 @@ def add_tcp_client_callback(self, host: str, port: int, callback: Callable): """ self._tcp_callbacks[(host, port)] = callback - def connectTCP(self, host: str, port: int, factory, timeout=30, bindAddress=None): + def connectTCP( + self, + host: str, + port: int, + factory: ClientFactory, + timeout: float = 30, + bindAddress: Optional[Tuple[str, int]] = None, + ) -> IConnector: """Fake L{IReactorTCP.connectTCP}.""" conn = super().connectTCP( @@ -495,7 +530,7 @@ def connectTCP(self, host: str, port: int, factory, timeout=30, bindAddress=None return conn - def advance(self, amount): + def advance(self, amount: float) -> None: # first advance our reactor's time, and run any "callLater" callbacks that # makes ready super().advance(amount) From 1cabd46df2da8f14c377bbe5723a3f57e3fe8ef9 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 16 Feb 2023 09:22:55 -0500 Subject: [PATCH 05/20] Clean-up ThreadPool. 
--- tests/server.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/tests/server.py b/tests/server.py index e09e86ccfd98..274119401327 100644 --- a/tests/server.py +++ b/tests/server.py @@ -39,7 +39,7 @@ from unittest.mock import Mock import attr -from typing_extensions import Deque +from typing_extensions import Deque, ParamSpec from zope.interface import implementer, providedBy from twisted.internet import address, threads, udp @@ -97,6 +97,7 @@ logger = logging.getLogger(__name__) R = TypeVar("R") +P = ParamSpec("P") # the type of thing that can be passed into `make_request` in the headers list CustomHeaderType = Tuple[Union[str, bytes], Union[str, bytes]] @@ -558,25 +559,33 @@ def advance(self, amount: float) -> None: class ThreadPool: """ Threadless thread pool. + + See twisted.python.threadpool.ThreadPool """ - def __init__(self, reactor): + def __init__(self, reactor: IReactorTime): self._reactor = reactor - def start(self): + def start(self) -> None: pass - def stop(self): + def stop(self) -> None: pass - def callInThreadWithCallback(self, onResult, function, *args, **kwargs): - def _(res): + def callInThreadWithCallback( + self, + onResult: Callable[[bool, Union[Failure, R]], None], + function: Callable[P, R], + *args: P.args, + **kwargs: P.kwargs, + ) -> "Deferred[None]": + def _(res: Any) -> None: if isinstance(res, Failure): onResult(False, res) else: onResult(True, res) - d = Deferred() + d: "Deferred[None]" = Deferred() d.addCallback(lambda x: function(*args, **kwargs)) d.addBoth(_) self._reactor.callLater(0, d.callback, True) From aa5842b9e2a6e9aa980ea652194b313e3b3272e1 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 16 Feb 2023 09:29:15 -0500 Subject: [PATCH 06/20] Fix-up _make_test_homeserver_synchronous. 
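For context, the wrappers below delegate to threads.deferToThreadPool. A minimal standalone sketch of that API using only public Twisted entry points (task.react, reactor.getThreadPool); it is not taken from the Synapse test suite:

    from twisted.internet import task, threads

    def main(reactor):
        # Push a blocking callable onto the reactor's thread pool; the result
        # comes back as a Deferred that fires in the reactor thread.
        d = threads.deferToThreadPool(
            reactor, reactor.getThreadPool(), lambda x, y: x + y, 2, 3
        )
        d.addCallback(print)  # prints 5
        return d

    task.react(main)

The test wrappers keep this calling convention but run pool._runWithConnection / pool._runInteraction through the threadless test ThreadPool instead of a real one.
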
--- tests/server.py | 34 ++++++++++++++++++++++++++-------- 1 file changed, 26 insertions(+), 8 deletions(-) diff --git a/tests/server.py b/tests/server.py index 274119401327..bae67f722506 100644 --- a/tests/server.py +++ b/tests/server.py @@ -602,30 +602,48 @@ def _make_test_homeserver_synchronous(server: HomeServer) -> None: for database in server.get_datastores().databases: pool = database._db_pool - def runWithConnection(func, *args, **kwargs): - return threads.deferToThreadPool( + async def runWithConnection( + func: Callable[..., R], + *args: Any, + db_autocommit: bool = False, + isolation_level: Optional[int] = None, + **kwargs: Any, + ) -> R: + return await threads.deferToThreadPool( pool._reactor, pool.threadpool, pool._runWithConnection, func, *args, + db_autocommit, + isolation_level, **kwargs, ) - def runInteraction(interaction, *args, **kwargs): - return threads.deferToThreadPool( + async def runInteraction( + desc: str, + func: Callable[..., R], + *args: Any, + db_autocommit: bool = False, + isolation_level: Optional[int] = None, + **kwargs: Any, + ) -> R: + return await threads.deferToThreadPool( pool._reactor, pool.threadpool, pool._runInteraction, - interaction, + desc, + func, *args, + db_autocommit, + isolation_level, **kwargs, ) - pool.runWithConnection = runWithConnection - pool.runInteraction = runInteraction + pool.runWithConnection = runWithConnection # type: ignore[assignment] + pool.runInteraction = runInteraction # type: ignore[assignment] # Replace the thread pool with a threadless 'thread' pool - pool.threadpool = ThreadPool(clock._reactor) + pool.threadpool = ThreadPool(clock._reactor) # type: ignore[assignment] pool.running = True # We've just changed the Databases to run DB transactions on the same From 954b7489e91b6eb02d1c571c15c929f8dae3e92e Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 16 Feb 2023 09:53:38 -0500 Subject: [PATCH 07/20] Fix-up setup_test_homeserver. --- tests/server.py | 43 +++++++++++++++++++++++++------------------ 1 file changed, 25 insertions(+), 18 deletions(-) diff --git a/tests/server.py b/tests/server.py index bae67f722506..3e173243d1cf 100644 --- a/tests/server.py +++ b/tests/server.py @@ -69,6 +69,7 @@ from twisted.web.server import Request, Site from synapse.config.database import DatabaseConnectionConfig +from synapse.config.homeserver import HomeServerConfig from synapse.events.presence_router import load_legacy_presence_router from synapse.events.spamcheck import load_legacy_spam_checkers from synapse.events.third_party_rules import load_legacy_third_party_event_rules @@ -835,17 +836,17 @@ def connect_client( class TestHomeServer(HomeServer): - DATASTORE_CLASS = DataStore + DATASTORE_CLASS = DataStore # type: ignore[assignment] def setup_test_homeserver( - cleanup_func, - name="test", - config=None, - reactor=None, + cleanup_func: Callable[[Callable[[], None]], None], + name: str = "test", + config: Optional[HomeServerConfig] = None, + reactor: Optional[ISynapseReactor] = None, homeserver_to_use: Type[HomeServer] = TestHomeServer, - **kwargs, -): + **kwargs: Any, +) -> HomeServer: """ Setup a homeserver suitable for running tests against. Keyword arguments are passed to the Homeserver constructor. @@ -860,13 +861,14 @@ def setup_test_homeserver( HomeserverTestCase. 
""" if reactor is None: - from twisted.internet import reactor + from twisted.internet import reactor as _reactor + + reactor = cast(ISynapseReactor, _reactor) if config is None: config = default_config(name, parse=True) config.caches.resize_all_caches() - config.ldap_enabled = False if "clock" not in kwargs: kwargs["clock"] = MockClock() @@ -917,6 +919,8 @@ def setup_test_homeserver( # Create the database before we actually try and connect to it, based off # the template database we generate in setupdb() if isinstance(db_engine, PostgresEngine): + import psycopg2.extensions + db_conn = db_engine.module.connect( database=POSTGRES_BASE_DB, user=POSTGRES_USER, @@ -924,6 +928,7 @@ def setup_test_homeserver( port=POSTGRES_PORT, password=POSTGRES_PASSWORD, ) + assert isinstance(db_conn, psycopg2.extensions.connection) db_conn.autocommit = True cur = db_conn.cursor() cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,)) @@ -952,14 +957,15 @@ def setup_test_homeserver( hs.setup_background_tasks() if isinstance(db_engine, PostgresEngine): - database = hs.get_datastores().databases[0] + database_pool = hs.get_datastores().databases[0] # We need to do cleanup on PostgreSQL - def cleanup(): + def cleanup() -> None: import psycopg2 + import psycopg2.extensions # Close all the db pools - database._db_pool.close() + database_pool._db_pool.close() dropped = False @@ -971,6 +977,7 @@ def cleanup(): port=POSTGRES_PORT, password=POSTGRES_PASSWORD, ) + assert isinstance(db_conn, psycopg2.extensions.connection) db_conn.autocommit = True cur = db_conn.cursor() @@ -1003,23 +1010,23 @@ def cleanup(): # Need to let the HS build an auth handler and then mess with it # because AuthHandler's constructor requires the HS, so we can't make one # beforehand and pass it in to the HS's constructor (chicken / egg) - async def hash(p): + async def hash(p: str) -> str: return hashlib.md5(p.encode("utf8")).hexdigest() - hs.get_auth_handler().hash = hash + hs.get_auth_handler().hash = hash # type: ignore[assignment] - async def validate_hash(p, h): + async def validate_hash(p: str, h: str) -> bool: return hashlib.md5(p.encode("utf8")).hexdigest() == h - hs.get_auth_handler().validate_hash = validate_hash + hs.get_auth_handler().validate_hash = validate_hash # type: ignore[assignment] # Make the threadpool and database transactions synchronous for testing. _make_test_homeserver_synchronous(hs) # Load any configured modules into the homeserver module_api = hs.get_module_api() - for module, config in hs.config.modules.loaded_modules: - module(config=config, api=module_api) + for module, module_config in hs.config.modules.loaded_modules: + module(config=module_config, api=module_api) load_legacy_spam_checkers(hs) load_legacy_third_party_event_rules(hs) From 43e7d11a3ce06c11ef5f5e348d31f86ac211b942 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 16 Feb 2023 09:55:44 -0500 Subject: [PATCH 08/20] Clean-up more type hints on FakeTransport. --- tests/server.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/tests/server.py b/tests/server.py index 3e173243d1cf..a6929ee8fbac 100644 --- a/tests/server.py +++ b/tests/server.py @@ -691,10 +691,14 @@ class FakeTransport: will get called back for connectionLost() notifications etc. 
""" - _peer_address: Optional[IAddress] = None + _peer_address: IAddress = attr.Factory( + lambda: address.IPv4Address("TCP", "127.0.0.1", 5678) + ) """The value to be returned by getPeer""" - _host_address: Optional[IAddress] = None + _host_address: IAddress = attr.Factory( + lambda: address.IPv4Address("TCP", "127.0.0.1", 1234) + ) """The value to be returned by getHost""" disconnecting = False @@ -704,10 +708,10 @@ class FakeTransport: producer: Optional[IPushProducer] = None autoflush: bool = True - def getPeer(self) -> Optional[IAddress]: + def getPeer(self) -> IAddress: return self._peer_address - def getHost(self) -> Optional[IAddress]: + def getHost(self) -> IAddress: return self._host_address def loseConnection(self, reason: Optional[Failure] = None) -> None: From e0562d6351241de92722af056f013a10f24f1038 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 16 Feb 2023 10:31:19 -0500 Subject: [PATCH 09/20] Fix-up make_request. --- tests/rest/client/test_auth.py | 14 ++------ tests/rest/client/utils.py | 58 ++++++++++++++++++++++------------ tests/server.py | 2 +- tests/unittest.py | 17 ++++++++-- 4 files changed, 57 insertions(+), 34 deletions(-) diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py index f4e1e7de4352..a1446100780c 100644 --- a/tests/rest/client/test_auth.py +++ b/tests/rest/client/test_auth.py @@ -34,7 +34,7 @@ from tests import unittest from tests.handlers.test_oidc import HAS_OIDC from tests.rest.client.utils import TEST_OIDC_CONFIG, TEST_OIDC_ISSUER -from tests.server import FakeChannel, make_request +from tests.server import FakeChannel from tests.unittest import override_config, skip_unless @@ -1322,16 +1322,8 @@ def test_logout_during_login(self) -> None: channel = self.submit_logout_token(logout_token) self.assertEqual(channel.code, 200) - # Now try to exchange the login token - channel = make_request( - self.hs.get_reactor(), - self.site, - "POST", - "/login", - content={"type": "m.login.token", "token": login_token}, - ) - # It should have failed - self.assertEqual(channel.code, 403) + # Now try to exchange the login token, it should fail. 
+ self.helper.login_via_token(login_token, 403) @override_config( { diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py index 8d6f2b6ff9cc..9532e5ddc102 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py @@ -36,6 +36,7 @@ import attr from typing_extensions import Literal +from twisted.test.proto_helpers import MemoryReactorClock from twisted.web.resource import Resource from twisted.web.server import Site @@ -67,6 +68,7 @@ class RestHelper: """ hs: HomeServer + reactor: MemoryReactorClock site: Site auth_user_id: Optional[str] @@ -142,7 +144,7 @@ def create_room_as( path = path + "?access_token=%s" % tok channel = make_request( - self.hs.get_reactor(), + self.reactor, self.site, "POST", path, @@ -216,7 +218,7 @@ def knock( data["reason"] = reason channel = make_request( - self.hs.get_reactor(), + self.reactor, self.site, "POST", path, @@ -313,7 +315,7 @@ def change_membership( data.update(extra_data or {}) channel = make_request( - self.hs.get_reactor(), + self.reactor, self.site, "PUT", path, @@ -394,7 +396,7 @@ def send_event( path = path + "?access_token=%s" % tok channel = make_request( - self.hs.get_reactor(), + self.reactor, self.site, "PUT", path, @@ -433,7 +435,7 @@ def get_event( path = path + f"?access_token={tok}" channel = make_request( - self.hs.get_reactor(), + self.reactor, self.site, "GET", path, @@ -488,7 +490,7 @@ def _read_write_state( if body is not None: content = json.dumps(body).encode("utf8") - channel = make_request(self.hs.get_reactor(), self.site, method, path, content) + channel = make_request(self.reactor, self.site, method, path, content) assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % ( expect_code, @@ -573,8 +575,8 @@ def upload_media( image_length = len(image_data) path = "/_matrix/media/r0/upload?filename=%s" % (filename,) channel = make_request( - self.hs.get_reactor(), - FakeSite(resource, self.hs.get_reactor()), + self.reactor, + FakeSite(resource, self.reactor), "POST", path, content=image_data, @@ -603,7 +605,7 @@ def whoami( expect_code: The return code to expect from attempting the whoami request """ channel = make_request( - self.hs.get_reactor(), + self.reactor, self.site, "GET", "account/whoami", @@ -642,7 +644,7 @@ def login_via_oidc( ) -> Tuple[JsonDict, FakeAuthorizationGrant]: """Log in (as a new user) via OIDC - Returns the result of the final token login. + Returns the result of the final token login and the fake authorization grant. Requires that "oidc_config" in the homeserver config be set appropriately (TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a @@ -672,10 +674,28 @@ def login_via_oidc( assert m, channel.text_body login_token = m.group(1) - # finally, submit the matrix login token to the login API, which gives us our - # matrix access token and device id. + return self.login_via_token(login_token, expected_status), grant + + def login_via_token( + self, + login_token: str, + expected_status: int = 200, + ) -> JsonDict: + """Submit the matrix login token to the login API, which gives us our + matrix access token and device id.Log in (as a new user) via OIDC + + Returns the result of the token login. + + Requires that "oidc_config" in the homeserver config be set appropriately + (TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a + "public_base_url". + + Also requires the login servlet and the OIDC callback resource to be mounted at + the normal places. 
+ """ + channel = make_request( - self.hs.get_reactor(), + self.reactor, self.site, "POST", "/login", @@ -684,7 +704,7 @@ def login_via_oidc( assert ( channel.code == expected_status ), f"unexpected status in response: {channel.code}" - return channel.json_body, grant + return channel.json_body def auth_via_oidc( self, @@ -805,7 +825,7 @@ def complete_oidc_auth( with fake_serer.patch_homeserver(hs=self.hs): # now hit the callback URI with the right params and a made-up code channel = make_request( - self.hs.get_reactor(), + self.reactor, self.site, "GET", callback_uri, @@ -849,7 +869,7 @@ def initiate_sso_login( # is the easiest way of figuring out what the Host header ought to be set to # to keep Synapse happy. channel = make_request( - self.hs.get_reactor(), + self.reactor, self.site, "GET", uri, @@ -867,7 +887,7 @@ def get_location(channel: FakeChannel) -> str: location = get_location(channel) parts = urllib.parse.urlsplit(location) channel = make_request( - self.hs.get_reactor(), + self.reactor, self.site, "GET", urllib.parse.urlunsplit(("", "") + parts[2:]), @@ -900,9 +920,7 @@ def initiate_sso_ui_auth( + urllib.parse.urlencode({"session": ui_auth_session_id}) ) # hit the redirect url (which will issue a cookie and state) - channel = make_request( - self.hs.get_reactor(), self.site, "GET", sso_redirect_endpoint - ) + channel = make_request(self.reactor, self.site, "GET", sso_redirect_endpoint) # that should serve a confirmation page assert channel.code == HTTPStatus.OK, channel.text_body channel.extract_cookies(cookies) diff --git a/tests/server.py b/tests/server.py index a6929ee8fbac..6bed73148de9 100644 --- a/tests/server.py +++ b/tests/server.py @@ -318,7 +318,7 @@ def getResourceFor(self, request: Request) -> IResource: def make_request( - reactor, + reactor: MemoryReactorClock, site: Union[Site, FakeSite], method: Union[bytes, str], path: Union[bytes, str], diff --git a/tests/unittest.py b/tests/unittest.py index c1cb5933faed..d995b5b1c5bb 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -34,6 +34,7 @@ Type, TypeVar, Union, + cast, ) from unittest.mock import Mock, patch @@ -45,7 +46,7 @@ from twisted.internet.defer import Deferred, ensureDeferred from twisted.python.failure import Failure from twisted.python.threadpool import ThreadPool -from twisted.test.proto_helpers import MemoryReactor +from twisted.test.proto_helpers import MemoryReactor, MemoryReactorClock from twisted.trial import unittest from twisted.web.resource import Resource from twisted.web.server import Request @@ -296,7 +297,19 @@ def setUp(self) -> None: from tests.rest.client.utils import RestHelper - self.helper = RestHelper(self.hs, self.site, getattr(self, "user_id", None)) + # HomeServer's reactor is ISynapseReactor, but for tests it should be + # MemoryReactorClock, which some of the internal mechanisms of tests + # depend on. + # + # Attempting to assert that here causes mypy to think the rest the code + # below the assertion to be unreachable, so just cast it. Hopefully this + # is true! 
+ self.helper = RestHelper( + self.hs, + cast(MemoryReactorClock, self.hs.get_reactor()), + self.site, + getattr(self, "user_id", None), + ) if hasattr(self, "user_id"): if self.hijack_auth: From d3cbcba3ba840a62c48336e706cc8997b25f5618 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 16 Feb 2023 13:56:35 -0500 Subject: [PATCH 10/20] Newsfragment --- changelog.d/15084.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/15084.misc diff --git a/changelog.d/15084.misc b/changelog.d/15084.misc new file mode 100644 index 000000000000..93ceaeafc9b9 --- /dev/null +++ b/changelog.d/15084.misc @@ -0,0 +1 @@ +Improve type hints. From 82987176bcb9a36f36517a0fa2abc639c381d73f Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 16 Feb 2023 13:57:11 -0500 Subject: [PATCH 11/20] Fix-up bad import. --- tests/appservice/test_scheduler.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py index febcc1499d06..a2d248b33382 100644 --- a/tests/appservice/test_scheduler.py +++ b/tests/appservice/test_scheduler.py @@ -17,6 +17,7 @@ from typing_extensions import TypeAlias from twisted.internet import defer +from twisted.test.proto_helpers import MemoryReactor from synapse.appservice import ( ApplicationService, @@ -40,9 +41,6 @@ from ..utils import MockClock -if TYPE_CHECKING: - from twisted.internet.testing import MemoryReactor - class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): def setUp(self) -> None: From 838ace3bb40d2a9b0c91acc9ef177f65afb8463c Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 16 Feb 2023 14:37:09 -0500 Subject: [PATCH 12/20] Remove unused import. --- tests/appservice/test_scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py index a2d248b33382..e2a3bad065da 100644 --- a/tests/appservice/test_scheduler.py +++ b/tests/appservice/test_scheduler.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, List, Optional, Sequence, Tuple, cast +from typing import List, Optional, Sequence, Tuple, cast from unittest.mock import Mock from typing_extensions import TypeAlias From e7abbd231d2267b66636a758ea15d22a09493df1 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 17 Feb 2023 08:30:35 -0500 Subject: [PATCH 13/20] Use checked_cast. 
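checked_cast lives in tests.utils and its definition is not shown in this diff; the sketch below only illustrates the general idea (a typing.cast that is verified at runtime) and the exact helper may differ:

    from typing import Type, TypeVar

    T = TypeVar("T")

    def checked_cast(cls: Type[T], value: object) -> T:
        # Sketch only; the real helper in tests.utils may differ.
        assert isinstance(value, cls), f"expected {cls.__name__}, got {type(value).__name__}"
        return value

    value: object = 42
    num = checked_cast(int, value)  # mypy now knows `num` is an int
    print(num + 1)                  # 43
    # checked_cast(str, value)      # would raise AssertionError at runtime

This collapses the previous two-step dance (assert the transport is not None, then assert isinstance) into a single call, which is why the agent and proxy tests below get shorter.
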
--- .../http/federation/test_matrix_federation_agent.py | 4 +--- tests/http/test_proxyagent.py | 4 +--- tests/unittest.py | 12 ++---------- 3 files changed, 4 insertions(+), 16 deletions(-) diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index 8d0a2d56aae9..eb7f53fee502 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -467,9 +467,7 @@ def _do_get_via_proxy( assert isinstance(proxy_server_transport, FakeTransport) client_protocol = proxy_server_transport.other assert isinstance(client_protocol, Protocol) - c2s_transport = client_protocol.transport - assert c2s_transport is not None - assert isinstance(c2s_transport, FakeTransport) + c2s_transport = checked_cast(FakeTransport, client_protocol.transport) c2s_transport.other = server_ssl_protocol self.reactor.advance(0) diff --git a/tests/http/test_proxyagent.py b/tests/http/test_proxyagent.py index ba6e09bbf0b3..cc175052ac78 100644 --- a/tests/http/test_proxyagent.py +++ b/tests/http/test_proxyagent.py @@ -645,9 +645,7 @@ def _do_https_request_via_proxy( assert isinstance(proxy_server_transport, FakeTransport) client_protocol = proxy_server_transport.other assert isinstance(client_protocol, Protocol) - c2s_transport = client_protocol.transport - assert c2s_transport is not None - assert isinstance(c2s_transport, FakeTransport) + c2s_transport = checked_cast(FakeTransport, client_protocol.transport) c2s_transport.other = server_ssl_protocol self.reactor.advance(0) diff --git a/tests/unittest.py b/tests/unittest.py index d995b5b1c5bb..b21e7f122196 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -34,7 +34,6 @@ Type, TypeVar, Union, - cast, ) from unittest.mock import Mock, patch @@ -83,7 +82,7 @@ ) from tests.test_utils import event_injection, setup_awaitable_errors from tests.test_utils.logging_setup import setup_logging -from tests.utils import default_config, setupdb +from tests.utils import checked_cast, default_config, setupdb setupdb() setup_logging() @@ -297,16 +296,9 @@ def setUp(self) -> None: from tests.rest.client.utils import RestHelper - # HomeServer's reactor is ISynapseReactor, but for tests it should be - # MemoryReactorClock, which some of the internal mechanisms of tests - # depend on. - # - # Attempting to assert that here causes mypy to think the rest the code - # below the assertion to be unreachable, so just cast it. Hopefully this - # is true! self.helper = RestHelper( self.hs, - cast(MemoryReactorClock, self.hs.get_reactor()), + checked_cast(MemoryReactorClock, self.hs.get_reactor()), self.site, getattr(self, "user_id", None), ) From 121638aa7f3939477342d1060e484af4bd2c0d97 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 17 Feb 2023 08:31:23 -0500 Subject: [PATCH 14/20] Bytes is immutable. --- tests/server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/server.py b/tests/server.py index 6bed73148de9..893906b22634 100644 --- a/tests/server.py +++ b/tests/server.py @@ -704,7 +704,7 @@ class FakeTransport: disconnecting = False disconnected = False connected = True - buffer: bytes = attr.Factory(bytes) + buffer: bytes = b"" producer: Optional[IPushProducer] = None autoflush: bool = True From 591a8af441e6c04bb1ef5a8c9cf949d201b39e1b Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 17 Feb 2023 08:31:49 -0500 Subject: [PATCH 15/20] Remove type hints in comments. 
--- tests/server.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/server.py b/tests/server.py index 893906b22634..e036ebf5e78a 100644 --- a/tests/server.py +++ b/tests/server.py @@ -676,14 +676,10 @@ class FakeTransport: other: IProtocol """The Protocol object which will receive any data written to this transport. - - :type: twisted.internet.interfaces.IProtocol """ _reactor: IReactorTime """Test reactor - - :type: twisted.internet.interfaces.IReactorTime """ _protocol: Optional[IProtocol] = None From 011c5cdcafb2e46cd038064d9ca13ed5da833c49 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 17 Feb 2023 08:44:37 -0500 Subject: [PATCH 16/20] Fix-up thread-pool wrappers. --- tests/server.py | 30 +++++++++--------------------- 1 file changed, 9 insertions(+), 21 deletions(-) diff --git a/tests/server.py b/tests/server.py index e036ebf5e78a..6ced0fab5cb6 100644 --- a/tests/server.py +++ b/tests/server.py @@ -23,6 +23,7 @@ from io import SEEK_END, BytesIO from typing import ( Any, + Awaitable, Callable, Dict, Iterable, @@ -603,41 +604,28 @@ def _make_test_homeserver_synchronous(server: HomeServer) -> None: for database in server.get_datastores().databases: pool = database._db_pool - async def runWithConnection( - func: Callable[..., R], - *args: Any, - db_autocommit: bool = False, - isolation_level: Optional[int] = None, - **kwargs: Any, - ) -> R: - return await threads.deferToThreadPool( + def runWithConnection( + func: Callable[..., R], *args: Any, **kwargs: Any + ) -> Awaitable[R]: + return threads.deferToThreadPool( pool._reactor, pool.threadpool, pool._runWithConnection, func, *args, - db_autocommit, - isolation_level, **kwargs, ) - async def runInteraction( - desc: str, - func: Callable[..., R], - *args: Any, - db_autocommit: bool = False, - isolation_level: Optional[int] = None, - **kwargs: Any, - ) -> R: - return await threads.deferToThreadPool( + def runInteraction( + desc: str, func: Callable[..., R], *args: Any, **kwargs: Any + ) -> Awaitable[R]: + return threads.deferToThreadPool( pool._reactor, pool.threadpool, pool._runInteraction, desc, func, *args, - db_autocommit, - isolation_level, **kwargs, ) From f2d22a77a106724c66086794e49213556e937b27 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 17 Feb 2023 08:51:44 -0500 Subject: [PATCH 17/20] Fix-up FakeTransport.loseConnection. --- tests/server.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/server.py b/tests/server.py index 6ced0fab5cb6..eb2cd6c0fabc 100644 --- a/tests/server.py +++ b/tests/server.py @@ -698,12 +698,12 @@ def getPeer(self) -> IAddress: def getHost(self) -> IAddress: return self._host_address - def loseConnection(self, reason: Optional[Failure] = None) -> None: + def loseConnection(self) -> None: if not self.disconnecting: - logger.info("FakeTransport: loseConnection(%s)", reason) + logger.info("FakeTransport: loseConnection()") self.disconnecting = True if self._protocol: - self._protocol.connectionLost(reason) # type: ignore[arg-type] + self._protocol.connectionLost(Failure()) # if we still have data to write, delay until that is done if self.buffer: From 8ee38558918971592497ec4f5923a8feeb7d231e Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 17 Feb 2023 10:48:35 -0500 Subject: [PATCH 18/20] Remove assertion. 
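Background: providedBy() only reports the interfaces a class declares via @implementer, not the ones it merely satisfies, and twisted.protocols.basic.FileSender implements the producer methods without declaring IPullProducer (see the comment kept below). A small sketch with an invented IGreeter interface:

    from zope.interface import Interface, implementer, providedBy

    class IGreeter(Interface):
        def greet():
            """Return a greeting."""

    @implementer(IGreeter)
    class Declared:
        def greet(self) -> str:
            return "hello"

    class Undeclared:
        # Satisfies the interface structurally, but never declares it.
        def greet(self) -> str:
            return "hello"

    print(IGreeter.providedBy(Declared()))                           # True
    print(not set(providedBy(Undeclared())).isdisjoint({IGreeter}))  # False

So an assertion of the removed form would wrongly reject a producer like FileSender even though it works fine at runtime.
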
--- tests/server.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/server.py b/tests/server.py index eb2cd6c0fabc..e54996a25350 100644 --- a/tests/server.py +++ b/tests/server.py @@ -41,7 +41,7 @@ import attr from typing_extensions import Deque, ParamSpec -from zope.interface import implementer, providedBy +from zope.interface import implementer from twisted.internet import address, threads, udp from twisted.internet._resolver import SimpleResolverComplexifier @@ -206,8 +206,9 @@ def loseConnection(self) -> None: # Type ignore: mypy doesn't like the fact that producer isn't an IProducer. def registerProducer(self, producer: IProducer, streaming: bool) -> None: - # Ensure that the producer implements one or more of IPushProducer and IPullProducer. - assert not set(providedBy(producer)).isdisjoint({IPushProducer, IPullProducer}) + # TODO This should ensure that the IProducer is an IPushProducer or + # IPullProducer, unfortunately twisted.protocols.basic.FileSender does + # implement those, but doesn't declare it. self._producer = cast(Union[IPushProducer, IPullProducer], producer) self.producerStreaming = streaming From 98dc5bbd80a67f48df016646088a8d8edfd9b91d Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 17 Feb 2023 11:36:24 -0500 Subject: [PATCH 19/20] Generate an exception for the failure. --- tests/server.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/server.py b/tests/server.py index e54996a25350..35c4b444f9a5 100644 --- a/tests/server.py +++ b/tests/server.py @@ -43,6 +43,7 @@ from typing_extensions import Deque, ParamSpec from zope.interface import implementer + from twisted.internet import address, threads, udp from twisted.internet._resolver import SimpleResolverComplexifier from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed @@ -704,7 +705,9 @@ def loseConnection(self) -> None: logger.info("FakeTransport: loseConnection()") self.disconnecting = True if self._protocol: - self._protocol.connectionLost(Failure()) + self._protocol.connectionLost( + Failure(RuntimeError("FakeTransport.loseConnection()")) + ) # if we still have data to write, delay until that is done if self.buffer: From b8625d25599020f2a9c76cca34079cb74d85b266 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 17 Feb 2023 12:40:06 -0500 Subject: [PATCH 20/20] lint --- tests/server.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/server.py b/tests/server.py index 35c4b444f9a5..5de97227669c 100644 --- a/tests/server.py +++ b/tests/server.py @@ -43,7 +43,6 @@ from typing_extensions import Deque, ParamSpec from zope.interface import implementer - from twisted.internet import address, threads, udp from twisted.internet._resolver import SimpleResolverComplexifier from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed
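As the FakeTransport docstring notes, bidirectional traffic needs two instances, one per direction. A rough sketch of how such a pair is wired, assuming the Synapse test tree is importable as `tests`; the Echo and Recorder protocols are invented for illustration and this is not copied from any particular test:

    from twisted.internet.protocol import Protocol
    from twisted.test.proto_helpers import MemoryReactorClock

    from tests.server import FakeTransport  # assumes the Synapse test tree is on sys.path

    class Echo(Protocol):
        def dataReceived(self, data: bytes) -> None:
            assert self.transport is not None
            self.transport.write(data)

    class Recorder(Protocol):
        received = b""

        def dataReceived(self, data: bytes) -> None:
            self.received += data

    reactor = MemoryReactorClock()
    client, server = Recorder(), Echo()

    # Each FakeTransport delivers writes to the *other* side's protocol.
    client.makeConnection(FakeTransport(server, reactor, client))
    server.makeConnection(FakeTransport(client, reactor, server))

    assert client.transport is not None
    client.transport.write(b"ping")
    reactor.advance(0)  # run the queued zero-delay autoflushes in both directions
    assert client.received == b"ping"
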