Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve type checking #358

Merged
merged 9 commits into from
Jun 15, 2021
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/358.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
This change fixes small, simple errors raised by mypy.
H-Shay marked this conversation as resolved.
Show resolved Hide resolved
12 changes: 6 additions & 6 deletions sydent/db/valsession.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def setSendAttemptNumber(self, sid: int, attemptNo: int) -> None:
Updates the send attempt number for the session with the given ID.

:param sid: The ID of the session to update
:type sid: unicode
:type sid: int
:param attemptNo: The send attempt number to update the session with.
:type attemptNo: int
"""
Expand All @@ -149,7 +149,7 @@ def setValidated(self, sid: int, validated: bool) -> None:
Updates a session to set the validated flag to the given value.

:param sid: The ID of the session to update.
:type sid: unicode
:type sid: int
:param validated: The value to set the validated flag.
:type validated: bool
"""
Expand All @@ -166,7 +166,7 @@ def setMtime(self, sid: int, mtime: int) -> None:
Set the time of the last send attempt for the session with the given ID

:param sid: The ID of the session to update.
:type sid: unicode
:type sid: int
:param mtime: The time of the last send attempt for that session.
:type mtime: int
"""
Expand All @@ -183,7 +183,7 @@ def getSessionById(self, sid: int) -> Optional[ValidationSession]:
Retrieves the session matching the given sid.

:param sid: The ID of the session to retrieve.
:type sid: unicode
:type sid: int

:return: The retrieved session, or None if no session could be found with that
sid.
Expand All @@ -210,7 +210,7 @@ def getTokenSessionById(self, sid: int) -> Optional[ValidationSession]:
Retrieves a validation session using the session's ID.

:param sid: The ID of the session to retrieve.
:type sid: unicode
:type sid: int

:return: The validation session, or None if no session was found with that ID.
:rtype: ValidationSession or None
Expand Down Expand Up @@ -239,7 +239,7 @@ def getValidatedSession(self, sid: int, clientSecret: str) -> ValidationSession:
one passed in.

:param sid: The ID of the session to retrieve.
:type sid: unicode
:type sid: int
:param clientSecret: A client secret to check against the one retrieved from
the database.
:type clientSecret: unicode
Expand Down
1 change: 1 addition & 0 deletions sydent/http/httpcommon.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ def dataReceived(self, data) -> None:
self.deferred.errback(BodyExceededMaxSize())
# Close the connection (forcefully) since all the data will get
# discarded anyway.
assert self.transport is not None
self.transport.abortConnection()
H-Shay marked this conversation as resolved.
Show resolved Hide resolved

def connectionLost(self, reason=connectionDone) -> None:
Expand Down
12 changes: 8 additions & 4 deletions sydent/http/matrixfederationagent.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,11 @@
import logging
import random
import time
from typing import Generator, Optional
from typing import TYPE_CHECKING, Generator, Optional

import attr
from netaddr import IPAddress # type: ignore
from twisted.internet import defer
from twisted.internet.defer import Deferred
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
from twisted.internet.interfaces import IStreamClientEndpoint
from twisted.web.client import URI, Agent, HTTPConnectionPool, RedirectAgent
Expand All @@ -34,6 +33,9 @@
from sydent.util import json_decoder
from sydent.util.ttlcache import TTLCache

if TYPE_CHECKING:
from twisted.web.iweb import IBodyProducer
H-Shay marked this conversation as resolved.
Show resolved Hide resolved

# period to cache .well-known results for by default
WELL_KNOWN_DEFAULT_CACHE_PERIOD = 24 * 3600

Expand Down Expand Up @@ -167,6 +169,7 @@ def request(
else:
headers = headers.copy()

assert headers is not None
H-Shay marked this conversation as resolved.
Show resolved Hide resolved
if not headers.hasHeader(b"host"):
headers.addRawHeader(b"host", res.host_header)

Expand All @@ -189,7 +192,7 @@ def endpointForURI(_uri):
@defer.inlineCallbacks
def _route_matrix_uri(
self, parsed_uri: "URI", lookup_well_known: bool = True
) -> "Deferred":
) -> Generator:
"""Helper for `request`: determine the routing for a Matrix URI

:param parsed_uri: uri to route. Note that it should be parsed with
Expand Down Expand Up @@ -320,6 +323,7 @@ def _get_well_known(self, server_name: bytes) -> Generator:
result, cache_period = yield self._do_get_well_known(server_name)

if cache_period > 0:
assert self._well_known_cache is not None
H-Shay marked this conversation as resolved.
Show resolved Hide resolved
self._well_known_cache.set(server_name, result, cache_period)

defer.returnValue(result)
Expand Down Expand Up @@ -357,7 +361,7 @@ def _do_get_well_known(self, server_name: bytes) -> Generator:

# add some randomness to the TTL to avoid a stampeding herd every hour
# after startup
cache_period = WELL_KNOWN_INVALID_CACHE_PERIOD
cache_period: float = WELL_KNOWN_INVALID_CACHE_PERIOD
cache_period += random.uniform(0, WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER)
defer.returnValue((None, cache_period))
return
Expand Down
2 changes: 1 addition & 1 deletion sydent/http/srvresolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

logger = logging.getLogger(__name__)

SERVER_CACHE = {}
SERVER_CACHE: Dict[bytes, List["Server"]] = {}


@attr.s
Expand Down
5 changes: 3 additions & 2 deletions sydent/replication/peer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import configparser
import json
import logging
from typing import TYPE_CHECKING, Any, Dict
from typing import TYPE_CHECKING, Any, Dict, cast

import signedjson.key
import signedjson.sign
Expand Down Expand Up @@ -85,6 +85,7 @@ def pushUpdates(self, sgAssocs: Dict[int, Dict[str, Any]]) -> "Deferred":
"""
globalAssocStore = GlobalAssociationStore(self.sydent)
for localId in sgAssocs:
assert self.lastId is not None
H-Shay marked this conversation as resolved.
Show resolved Hide resolved
if localId > self.lastId:
assocObj = threePidAssocFromDict(sgAssocs[localId])

Expand Down Expand Up @@ -258,7 +259,7 @@ def _pushSuccess(
the status code.
:type updateDeferred: twisted.internet.defer.Deferred
"""
if result.code >= 200 and result.code < 300:
if cast(int, result.code) >= 200 and cast(int, result.code) < 300:
H-Shay marked this conversation as resolved.
Show resolved Hide resolved
updateDeferred.callback(result)
else:
d = readBody(result)
Expand Down
4 changes: 2 additions & 2 deletions sydent/sms/openmarket.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import logging
from base64 import b64encode
from typing import TYPE_CHECKING, Dict, Optional
from typing import TYPE_CHECKING, Dict, Generator, Optional

from twisted.internet import defer
from twisted.web.http_headers import Headers
Expand Down Expand Up @@ -64,7 +64,7 @@ def __init__(self, sydent: "Sydent") -> None:
@defer.inlineCallbacks
def sendTextSMS(
self, body: Dict, dest: str, source: Optional[Dict[str, str]] = None
) -> None:
) -> Generator:
"""
Sends a text message with the given body to the given MSISDN.

Expand Down
5 changes: 3 additions & 2 deletions sydent/threepid/bind.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import collections
import logging
import math
from typing import TYPE_CHECKING, Any, Dict, Union
from typing import TYPE_CHECKING, Any, Dict, Generator, Union

import signedjson.sign # type: ignore
from twisted.internet import defer
Expand Down Expand Up @@ -92,6 +92,7 @@ def addBinding(self, medium: str, address: str, mxid: str) -> Dict[str, Any]:
joinTokenStore = JoinTokenStore(self.sydent)
pendingJoinTokens = joinTokenStore.getTokens(medium, address)
invites = []
token: Any
H-Shay marked this conversation as resolved.
Show resolved Hide resolved
for token in pendingJoinTokens:
token["mxid"] = mxid
token["signed"] = {
Expand Down Expand Up @@ -127,7 +128,7 @@ def removeBinding(self, threepid: Dict[str, str], mxid: str) -> None:
self.sydent.pusher.doLocalPush()

@defer.inlineCallbacks
def _notify(self, assoc: Dict[str, Any], attempt: int) -> None:
def _notify(self, assoc: Dict[str, Any], attempt: int) -> Generator:
"""
Sends data about a new association (and, if necessary, the associated invites)
to the associated MXID's homeserver.
Expand Down
2 changes: 1 addition & 1 deletion sydent/util/emailutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def sendEmail(
)
try:
if mailTLSMode == "SSL" or mailTLSMode == "TLS":
smtp = smtplib.SMTP_SSL(mailServer, mailPort, myHostname)
smtp: Any = smtplib.SMTP_SSL(mailServer, mailPort, myHostname)
H-Shay marked this conversation as resolved.
Show resolved Hide resolved
elif mailTLSMode == "STARTTLS":
smtp = smtplib.SMTP(mailServer, mailPort, myHostname)
smtp.starttls()
Expand Down
2 changes: 1 addition & 1 deletion sydent/util/stringutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def is_valid_hostname(string: str) -> bool:
return hostname_regex.match(string) is not None


def parse_server_name(server_name: str) -> Tuple[str, Optional[int]]:
def parse_server_name(server_name: str) -> Tuple[str, Optional[str]]:
"""Split a server name into host/port parts.

No validation is done on the host part. The port part is validated to be
Expand Down
2 changes: 1 addition & 1 deletion sydent/validators/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@


def validateSessionWithToken(
sydent: "Sydent", sid: str, clientSecret: str, token: str
sydent: "Sydent", sid: int, clientSecret: str, token: str
H-Shay marked this conversation as resolved.
Show resolved Hide resolved
) -> Dict[str, bool]:
"""
Attempt to validate a session, identified by the sid, using
Expand Down
4 changes: 2 additions & 2 deletions sydent/validators/emailvalidator.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,13 +141,13 @@ def makeValidateLink(
return link

def validateSessionWithToken(
self, sid: str, clientSecret: str, token: str
self, sid: int, clientSecret: str, token: str
) -> Dict[str, bool]:
"""
Validates the session with the given ID.

:param sid: The ID of the session to validate.
:type sid: unicode
:type sid: int
:param clientSecret: The client secret to validate.
:type clientSecret: unicode
:param token: The token to validate.
Expand Down
4 changes: 2 additions & 2 deletions sydent/validators/msisdnvalidator.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,13 +175,13 @@ def getOriginator(
return origs[sum(int(i) for i in msisdn) % len(origs)]

def validateSessionWithToken(
self, sid: str, clientSecret: str, token: str
self, sid: int, clientSecret: str, token: str
) -> Dict[str, bool]:
"""
Validates the session with the given ID.

:param sid: The ID of the session to validate.
:type sid: unicode
:type sid: int
:param clientSecret: The client secret to validate.
:type clientSecret: unicode
:param token: The token to validate.
Expand Down