Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Reduce the number of "untyped defs" #12716

Merged
merged 14 commits into from
May 12, 2022
15 changes: 15 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,18 @@ disallow_untyped_defs = True
[mypy-synapse.federation.transport.client]
disallow_untyped_defs = False

[mypy-synapse.groups.*]
disallow_untyped_defs = True

[mypy-synapse.handlers.*]
disallow_untyped_defs = True

[mypy-synapse.http.federation.*]
disallow_untyped_defs = True

[mypy-synapse.http.request_metrics]
disallow_untyped_defs = True

[mypy-synapse.http.server]
disallow_untyped_defs = True

Expand Down Expand Up @@ -202,6 +211,12 @@ disallow_untyped_defs = True
[mypy-synapse.storage.databases.main.user_erasure_store]
disallow_untyped_defs = True

[mypy-synapse.storage.persist_events]
disallow_untyped_defs = True

[mypy-synapse.storage.types]
disallow_untyped_defs = True

[mypy-synapse.storage.util.*]
disallow_untyped_defs = True

Expand Down
2 changes: 1 addition & 1 deletion synapse/groups/groups_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -934,7 +934,7 @@ async def delete_group(self, group_id: str, requester_user_id: str) -> None:
# Before deleting the group lets kick everyone out of it
users = await self.store.get_users_in_group(group_id, include_private=True)

async def _kick_user_from_group(user_id):
async def _kick_user_from_group(user_id: str) -> None:
if self.hs.is_mine_id(user_id):
groups_local = self.hs.get_groups_local_handler()
assert isinstance(
Expand Down
16 changes: 10 additions & 6 deletions synapse/http/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,10 @@
from twisted.internet.address import IPv4Address, IPv6Address
from twisted.internet.interfaces import (
IAddress,
IDelayedCall,
IHostResolution,
IReactorPluggableNameResolver,
IReactorTime,
IResolutionReceiver,
ITCPTransport,
)
Expand Down Expand Up @@ -121,13 +123,15 @@ def check_against_blacklist(
_EPSILON = 0.00000001


def _make_scheduler(reactor):
def _make_scheduler(
reactor: IReactorTime,
) -> Callable[[Callable[[], object]], IDelayedCall]:
"""Makes a schedular suitable for a Cooperator using the given reactor.

(This is effectively just a copy from `twisted.internet.task`)
"""

def _scheduler(x):
def _scheduler(x: Callable[[], object]) -> IDelayedCall:
return reactor.callLater(_EPSILON, x)

return _scheduler
Expand Down Expand Up @@ -775,7 +779,7 @@ async def get_file(
)


def _timeout_to_request_timed_out_error(f: Failure):
def _timeout_to_request_timed_out_error(f: Failure) -> Failure:
if f.check(twisted_error.TimeoutError, twisted_error.ConnectingCancelledError):
# The TCP connection has its own timeout (set by the 'connectTimeout' param
# on the Agent), which raises twisted_error.TimeoutError exception.
Expand Down Expand Up @@ -809,7 +813,7 @@ class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
def __init__(self, deferred: defer.Deferred):
self.deferred = deferred

def _maybe_fail(self):
def _maybe_fail(self) -> None:
"""
Report a max size exceed error and disconnect the first time this is called.
"""
Expand Down Expand Up @@ -933,12 +937,12 @@ class InsecureInterceptableContextFactory(ssl.ContextFactory):
Do not use this since it allows an attacker to intercept your communications.
"""

def __init__(self):
def __init__(self) -> None:
self._context = SSL.Context(SSL.SSLv23_METHOD)
self._context.set_verify(VERIFY_NONE, lambda *_: False)

def getContext(self, hostname=None, port=None):
return self._context

def creatorForNetloc(self, hostname, port):
def creatorForNetloc(self, hostname: bytes, port: int):
return self
Comment on lines +947 to 948
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is supposed to return IOpenSSLClientConnectionCreator, but as far as I can see InsecureInterceptableContextFactory doesn't implement that interface. I chose to ignore this because this is meant to be for testing only.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I ran into this too, there's a bit of funkiness in some of these interfaces that they don't all seem to be accurate.

2 changes: 1 addition & 1 deletion synapse/http/federation/matrix_federation_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def __init__(

self._srv_resolver = srv_resolver

def endpointForURI(self, parsed_uri: URI):
def endpointForURI(self, parsed_uri: URI) -> "MatrixHostnameEndpoint":
return MatrixHostnameEndpoint(
self._reactor,
self._proxy_reactor,
Expand Down
4 changes: 2 additions & 2 deletions synapse/http/federation/srv_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import logging
import random
import time
from typing import Callable, Dict, List
from typing import Any, Callable, Dict, List

import attr

Expand Down Expand Up @@ -109,7 +109,7 @@ class SrvResolver:

def __init__(
self,
dns_client=client,
dns_client: Any = client,
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
cache: Dict[bytes, List[Server]] = SERVER_CACHE,
get_time: Callable[[], float] = time.time,
):
Expand Down
4 changes: 2 additions & 2 deletions synapse/http/federation/well_known_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@

@attr.s(slots=True, frozen=True)
class WellKnownLookupResult:
delegated_server = attr.ib()
delegated_server: Optional[bytes] = attr.ib()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we auto_attribs=True, while we're here? 😄

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I'll chuck this in before merging.



class WellKnownResolver:
Expand Down Expand Up @@ -336,4 +336,4 @@ def _parse_cache_control(headers: Headers) -> Dict[bytes, Optional[bytes]]:
class _FetchWellKnownFailure(Exception):
# True if we didn't get a non-5xx HTTP response, i.e. this may or may not be
# a temporary failure.
temporary = attr.ib()
temporary: bool = attr.ib()
31 changes: 19 additions & 12 deletions synapse/http/matrixfederationclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
from io import BytesIO, StringIO
from typing import (
TYPE_CHECKING,
Any,
BinaryIO,
Callable,
Dict,
Generic,
Expand All @@ -44,7 +46,7 @@
from twisted.internet import defer
from twisted.internet.error import DNSLookupError
from twisted.internet.interfaces import IReactorTime
from twisted.internet.task import _EPSILON, Cooperator
from twisted.internet.task import Cooperator
from twisted.web.client import ResponseFailed
from twisted.web.http_headers import Headers
from twisted.web.iweb import IBodyProducer, IResponse
Expand All @@ -58,11 +60,13 @@
RequestSendFailed,
SynapseError,
)
from synapse.crypto.context_factory import FederationPolicyForHTTPS
from synapse.http import QuieterFileBodyProducer
from synapse.http.client import (
BlacklistingAgentWrapper,
BodyExceededMaxSize,
ByteWriteable,
_make_scheduler,
encode_query_args,
read_body_with_max_size,
)
Expand Down Expand Up @@ -181,7 +185,7 @@ class JsonParser(ByteParser[Union[JsonDict, list]]):

CONTENT_TYPE = "application/json"

def __init__(self):
def __init__(self) -> None:
self._buffer = StringIO()
self._binary_wrapper = BinaryIOWrapper(self._buffer)

Expand Down Expand Up @@ -299,7 +303,9 @@ async def _handle_response(
class BinaryIOWrapper:
"""A wrapper for a TextIO which converts from bytes on the fly."""

def __init__(self, file: typing.TextIO, encoding="utf-8", errors="strict"):
def __init__(
self, file: typing.TextIO, encoding: str = "utf-8", errors: str = "strict"
):
self.decoder = codecs.getincrementaldecoder(encoding)(errors)
self.file = file

Expand All @@ -317,7 +323,11 @@ class MatrixFederationHttpClient:
requests.
"""

def __init__(self, hs: "HomeServer", tls_client_options_factory):
def __init__(
self,
hs: "HomeServer",
tls_client_options_factory: Optional[FederationPolicyForHTTPS],
):
self.hs = hs
self.signing_key = hs.signing_key
self.server_name = hs.hostname
Expand Down Expand Up @@ -348,10 +358,7 @@ def __init__(self, hs: "HomeServer", tls_client_options_factory):
self.version_string_bytes = hs.version_string.encode("ascii")
self.default_timeout = 60

def schedule(x):
self.reactor.callLater(_EPSILON, x)

self._cooperator = Cooperator(scheduler=schedule)
self._cooperator = Cooperator(scheduler=_make_scheduler(self.reactor))

self._sleeper = AwakenableSleeper(self.reactor)

Expand All @@ -364,7 +371,7 @@ async def _send_request_with_optional_trailing_slash(
self,
request: MatrixFederationRequest,
try_trailing_slash_on_400: bool = False,
**send_request_args,
**send_request_args: Any,
) -> IResponse:
"""Wrapper for _send_request which can optionally retry the request
upon receiving a combination of a 400 HTTP response code and a
Expand Down Expand Up @@ -1159,7 +1166,7 @@ async def get_file(
self,
destination: str,
path: str,
output_stream,
output_stream: BinaryIO,
args: Optional[QueryParams] = None,
retry_on_dns_fail: bool = True,
max_size: Optional[int] = None,
Expand Down Expand Up @@ -1250,10 +1257,10 @@ async def get_file(
return length, headers


def _flatten_response_never_received(e):
def _flatten_response_never_received(e: BaseException) -> str:
if hasattr(e, "reasons"):
reasons = ", ".join(
_flatten_response_never_received(f.value) for f in e.reasons
_flatten_response_never_received(f.value) for f in e.reasons # type: ignore[attr-defined]
)

return "%s:[%s]" % (type(e).__name__, reasons)
Expand Down
10 changes: 5 additions & 5 deletions synapse/http/request_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def start(self, time_sec: float, name: str, method: str) -> None:
with _in_flight_requests_lock:
_in_flight_requests.add(self)

def stop(self, time_sec, response_code, sent_bytes):
def stop(self, time_sec: float, response_code: int, sent_bytes: int) -> None:
with _in_flight_requests_lock:
_in_flight_requests.discard(self)

Expand All @@ -186,13 +186,13 @@ def stop(self, time_sec, response_code, sent_bytes):
)
return

response_code = str(response_code)
response_code_str = str(response_code)

outgoing_responses_counter.labels(self.method, response_code).inc()
outgoing_responses_counter.labels(self.method, response_code_str).inc()

response_count.labels(self.method, self.name, tag).inc()

response_timer.labels(self.method, self.name, tag, response_code).observe(
response_timer.labels(self.method, self.name, tag, response_code_str).observe(
time_sec - self.start_ts
)

Expand Down Expand Up @@ -221,7 +221,7 @@ def stop(self, time_sec, response_code, sent_bytes):
# flight.
self.update_metrics()

def update_metrics(self):
def update_metrics(self) -> None:
"""Updates the in flight metrics with values from this request."""
if not self.start_context:
logger.error(
Expand Down
21 changes: 13 additions & 8 deletions synapse/storage/persist_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
Collection,
Deque,
Dict,
Generator,
Generic,
Iterable,
List,
Expand Down Expand Up @@ -207,7 +208,7 @@ async def add_to_queue(

return res

def _handle_queue(self, room_id):
def _handle_queue(self, room_id: str) -> None:
"""Attempts to handle the queue for a room if not already being handled.

The queue's callback will be invoked with for each item in the queue,
Expand All @@ -227,7 +228,7 @@ def _handle_queue(self, room_id):

self._currently_persisting_rooms.add(room_id)

async def handle_queue_loop():
async def handle_queue_loop() -> None:
try:
queue = self._get_drainining_queue(room_id)
for item in queue:
Expand All @@ -250,15 +251,17 @@ async def handle_queue_loop():
with PreserveLoggingContext():
item.deferred.callback(ret)
finally:
queue = self._event_persist_queues.pop(room_id, None)
if queue:
self._event_persist_queues[room_id] = queue
remaining_queue = self._event_persist_queues.pop(room_id, None)
Copy link
Contributor Author

@DMRobertson DMRobertson May 11, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wasn't sure of the best name here. Mypy doesn't like it because it has a different type to the queue on line 233

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably fine, other options: to_persist_queue?

if remaining_queue:
self._event_persist_queues[room_id] = remaining_queue
self._currently_persisting_rooms.discard(room_id)

# set handle_queue_loop off in the background
run_as_background_process("persist_events", handle_queue_loop)

def _get_drainining_queue(self, room_id):
def _get_drainining_queue(
self, room_id: str
) -> Generator[_EventPersistQueueItem, None, None]:
queue = self._event_persist_queues.setdefault(room_id, deque())

try:
Expand Down Expand Up @@ -317,7 +320,9 @@ async def persist_events(
for event, ctx in events_and_contexts:
partitioned.setdefault(event.room_id, []).append((event, ctx))

async def enqueue(item):
async def enqueue(
item: Tuple[str, List[Tuple[EventBase, EventContext]]]
) -> Dict[str, str]:
room_id, evs_ctxs = item
return await self._event_persist_queue.add_to_queue(
room_id, evs_ctxs, backfilled=backfilled
Expand Down Expand Up @@ -1102,7 +1107,7 @@ async def _is_server_still_joined(

return False

async def _handle_potentially_left_users(self, user_ids: Set[str]):
async def _handle_potentially_left_users(self, user_ids: Set[str]) -> None:
"""Given a set of remote users check if the server still shares a room with
them. If not then mark those users' device cache as stale.
"""
Expand Down
10 changes: 8 additions & 2 deletions synapse/storage/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Iterator, List, Mapping, Optional, Sequence, Tuple, Union
from types import TracebackType
from typing import Any, Iterator, List, Mapping, Optional, Sequence, Tuple, Type, Union

from typing_extensions import Protocol

Expand Down Expand Up @@ -86,5 +87,10 @@ def rollback(self) -> None:
def __enter__(self) -> "Connection":
...

def __exit__(self, exc_type, exc_value, traceback) -> Optional[bool]:
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> Optional[bool]:
...