From 733272ad23538fb1698878bd5b669977d8f22b82 Mon Sep 17 00:00:00 2001 From: Hillery Shay Date: Wed, 9 Jun 2021 08:50:19 -0700 Subject: [PATCH] Add type hints for checking by mypy (#355) Signed-off-by H-Shay: --- changelog.d/355.misc | 1 + sydent/db/accounts.py | 19 ++++++---- sydent/db/hashing_metadata.py | 21 ++++++++--- sydent/db/invite_tokens.py | 22 +++++++----- sydent/db/peers.py | 17 +++++---- sydent/db/sqlitedb.py | 6 +++- sydent/db/terms.py | 11 ++++-- sydent/db/threepid_associations.py | 53 +++++++++++++++++++--------- sydent/db/valsession.py | 33 +++++++++++------ sydent/hs_federation/verifier.py | 29 ++++++++++----- sydent/http/auth.py | 18 +++++++--- sydent/http/httpclient.py | 17 ++++++--- sydent/http/httpcommon.py | 7 +++- sydent/http/httpsclient.py | 13 ++++--- sydent/http/httpserver.py | 10 ++++-- sydent/http/matrixfederationagent.py | 28 ++++++++++----- sydent/http/srvresolver.py | 13 +++++-- sydent/replication/peer.py | 45 ++++++++++++++++------- sydent/replication/pusher.py | 12 ++++--- sydent/sms/openmarket.py | 16 ++++++--- sydent/terms/terms.py | 14 ++++---- sydent/threepid/bind.py | 18 ++++++---- sydent/threepid/signer.py | 12 +++++-- sydent/users/accounts.py | 2 +- sydent/users/tokens.py | 7 +++- sydent/util/emailutils.py | 10 +++++- sydent/util/hash.py | 4 +-- sydent/util/ip_range.py | 2 +- sydent/util/stringutils.py | 2 +- sydent/util/tokenutils.py | 6 ++-- sydent/util/ttlcache.py | 7 ++-- sydent/validators/common.py | 8 ++++- sydent/validators/emailvalidator.py | 29 +++++++++------ sydent/validators/msisdnvalidator.py | 26 ++++++++++---- 34 files changed, 380 insertions(+), 158 deletions(-) create mode 100644 changelog.d/355.misc diff --git a/changelog.d/355.misc b/changelog.d/355.misc new file mode 100644 index 00000000..b1c52a2b --- /dev/null +++ b/changelog.d/355.misc @@ -0,0 +1 @@ +Added type hints to support mypy checks. \ No newline at end of file diff --git a/sydent/db/accounts.py b/sydent/db/accounts.py index 52b0cb85..d761082d 100644 --- a/sydent/db/accounts.py +++ b/sydent/db/accounts.py @@ -15,14 +15,19 @@ # limitations under the License. from __future__ import absolute_import +from typing import TYPE_CHECKING, Optional + from sydent.users.accounts import Account +if TYPE_CHECKING: + from sydent.sydent import Sydent + class AccountStore(object): - def __init__(self, sydent): + def __init__(self, sydent: "Sydent") -> None: self.sydent = sydent - def getAccountByToken(self, token): + def getAccountByToken(self, token: str) -> Optional[Account]: """ Select the account matching the given token, if any. @@ -45,7 +50,9 @@ def getAccountByToken(self, token): return Account(*row) - def storeAccount(self, user_id, creation_ts, consent_version): + def storeAccount( + self, user_id: str, creation_ts: int, consent_version: Optional[str] + ) -> None: """ Stores an account for the given user ID. @@ -65,7 +72,7 @@ def storeAccount(self, user_id, creation_ts, consent_version): ) self.sydent.db.commit() - def setConsentVersion(self, user_id, consent_version): + def setConsentVersion(self, user_id: str, consent_version: Optional[str]) -> None: """ Saves that the given user has agreed to all of the terms in the document of the given version. @@ -82,7 +89,7 @@ def setConsentVersion(self, user_id, consent_version): ) self.sydent.db.commit() - def addToken(self, user_id, token): + def addToken(self, user_id: str, token: str) -> None: """ Stores the authentication token for a given user. @@ -98,7 +105,7 @@ def addToken(self, user_id, token): ) self.sydent.db.commit() - def delToken(self, token): + def delToken(self, token: str) -> int: """ Deletes an authentication token from the database. diff --git a/sydent/db/hashing_metadata.py b/sydent/db/hashing_metadata.py index 111224f4..0a70ef0b 100644 --- a/sydent/db/hashing_metadata.py +++ b/sydent/db/hashing_metadata.py @@ -16,13 +16,18 @@ # Actions on the hashing_metadata table which is defined in the migration process in # sqlitedb.py +from sqlite3 import Cursor +from typing import TYPE_CHECKING, Callable, Optional + +if TYPE_CHECKING: + from sydent.sydent import Sydent class HashingMetadataStore: - def __init__(self, sydent): + def __init__(self, sydent: "Sydent") -> None: self.sydent = sydent - def get_lookup_pepper(self): + def get_lookup_pepper(self) -> Optional[str]: """Return the value of the current lookup pepper from the db :return: A pepper if it exists in the database, or None if one does @@ -44,7 +49,9 @@ def get_lookup_pepper(self): return pepper - def store_lookup_pepper(self, hashing_function, pepper): + def store_lookup_pepper( + self, hashing_function: Callable[[str], str], pepper: str + ) -> None: """Stores a new lookup pepper in the hashing_metadata db table and rehashes all 3PIDs :param hashing_function: A function with single input and output strings @@ -74,7 +81,13 @@ def store_lookup_pepper(self, hashing_function, pepper): # Commit the queued db transactions so that adding a new pepper and hashing is atomic self.sydent.db.commit() - def _rehash_threepids(self, cur, hashing_function, pepper, table): + def _rehash_threepids( + self, + cur: Cursor, + hashing_function: Callable[[str], str], + pepper: str, + table: str, + ) -> None: """Rehash 3PIDs of a given table using a given hashing_function and pepper A database cursor `cur` must be passed to this function. After this function completes, diff --git a/sydent/db/invite_tokens.py b/sydent/db/invite_tokens.py index 8b92b677..f419f286 100644 --- a/sydent/db/invite_tokens.py +++ b/sydent/db/invite_tokens.py @@ -14,13 +14,19 @@ # See the License for the specific language governing permissions and # limitations under the License. import time +from typing import TYPE_CHECKING, Dict, List, Optional + +if TYPE_CHECKING: + from sydent.sydent import Sydent class JoinTokenStore(object): - def __init__(self, sydent): + def __init__(self, sydent: "Sydent") -> None: self.sydent = sydent - def storeToken(self, medium, address, roomId, sender, token): + def storeToken( + self, medium: str, address: str, roomId: str, sender: str, token: str + ) -> None: """ Store a new invite token and its metadata. @@ -45,7 +51,7 @@ def storeToken(self, medium, address, roomId, sender, token): ) self.sydent.db.commit() - def getTokens(self, medium, address): + def getTokens(self, medium: str, address: str) -> List[Dict[str, str]]: """ Retrieves the pending invites tokens for this 3PID that haven't been delivered yet. @@ -100,7 +106,7 @@ def getTokens(self, medium, address): return ret - def markTokensAsSent(self, medium, address): + def markTokensAsSent(self, medium: str, address: str) -> None: """ Updates the invite tokens associated with a given 3PID to mark them as delivered to a homeserver so they're not delivered again in the future. @@ -122,7 +128,7 @@ def markTokensAsSent(self, medium, address): ) self.sydent.db.commit() - def storeEphemeralPublicKey(self, publicKey): + def storeEphemeralPublicKey(self, publicKey: str) -> None: """ Saves the provided ephemeral public key. @@ -138,7 +144,7 @@ def storeEphemeralPublicKey(self, publicKey): ) self.sydent.db.commit() - def validateEphemeralPublicKey(self, publicKey): + def validateEphemeralPublicKey(self, publicKey: str) -> bool: """ Checks if an ephemeral public key is valid, and, if it is, updates its verification count. @@ -159,7 +165,7 @@ def validateEphemeralPublicKey(self, publicKey): self.sydent.db.commit() return cur.rowcount > 0 - def getSenderForToken(self, token): + def getSenderForToken(self, token: str) -> Optional[str]: """ Retrieves the MXID of the user that sent the invite the provided token is for. @@ -177,7 +183,7 @@ def getSenderForToken(self, token): return rows[0][0] return None - def deleteTokens(self, medium, address): + def deleteTokens(self, medium: str, address: str) -> None: """ Deletes every token for a given 3PID. diff --git a/sydent/db/peers.py b/sydent/db/peers.py index 3c796ad6..64137983 100644 --- a/sydent/db/peers.py +++ b/sydent/db/peers.py @@ -15,14 +15,19 @@ # limitations under the License. from __future__ import absolute_import +from typing import TYPE_CHECKING, Dict, List, Optional + from sydent.replication.peer import RemotePeer +if TYPE_CHECKING: + from sydent.sydent import Sydent + class PeerStore: - def __init__(self, sydent): + def __init__(self, sydent: "Sydent") -> None: self.sydent = sydent - def getPeerByName(self, name): + def getPeerByName(self, name: str) -> Optional[RemotePeer]: """ Retrieves a remote peer using it's server name. @@ -57,7 +62,7 @@ def getPeerByName(self, name): return p - def getAllPeers(self): + def getAllPeers(self) -> List[RemotePeer]: """ Retrieve all of the remote peers from the database. @@ -75,7 +80,7 @@ def getAllPeers(self): peername = None port = None lastSentVer = None - pubkeys = {} + pubkeys: Dict[str, str] = {} for row in res.fetchall(): if row[0] != peername: @@ -95,8 +100,8 @@ def getAllPeers(self): return peers def setLastSentVersionAndPokeSucceeded( - self, peerName, lastSentVersion, lastPokeSucceeded - ): + self, peerName: str, lastSentVersion: int, lastPokeSucceeded: int + ) -> None: """ Sets the ID of the last association sent to a given peer and the time of the last successful request sent to that peer. diff --git a/sydent/db/sqlitedb.py b/sydent/db/sqlitedb.py index 4037cbc3..9c2fa3f1 100644 --- a/sydent/db/sqlitedb.py +++ b/sydent/db/sqlitedb.py @@ -18,12 +18,16 @@ import logging import os import sqlite3 +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from sydent.sydent import Sydent logger = logging.getLogger(__name__) class SqliteDatabase: - def __init__(self, syd): + def __init__(self, syd: "Sydent") -> None: self.sydent = syd dbFilePath = self.sydent.cfg.get("db", "db.file") diff --git a/sydent/db/terms.py b/sydent/db/terms.py index 6e27715a..99e15c42 100644 --- a/sydent/db/terms.py +++ b/sydent/db/terms.py @@ -14,12 +14,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import TYPE_CHECKING, List + +if TYPE_CHECKING: + from sydent.sydent import Sydent + class TermsStore(object): - def __init__(self, sydent): + def __init__(self, sydent: "Sydent") -> None: self.sydent = sydent - def getAgreedUrls(self, user_id): + def getAgreedUrls(self, user_id: str) -> List[str]: """ Retrieves the URLs of the terms the given user has agreed to. @@ -45,7 +50,7 @@ def getAgreedUrls(self, user_id): return urls - def addAgreedUrls(self, user_id, urls): + def addAgreedUrls(self, user_id: str, urls: List[str]) -> None: """ Saves that the given user has accepted the terms at the given URLs. diff --git a/sydent/db/threepid_associations.py b/sydent/db/threepid_associations.py index ee39b692..48227d5d 100644 --- a/sydent/db/threepid_associations.py +++ b/sydent/db/threepid_associations.py @@ -16,19 +16,23 @@ from __future__ import absolute_import import logging +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple from sydent.threepid import ThreepidAssociation from sydent.threepid.signer import Signer from sydent.util import time_msec +if TYPE_CHECKING: + from sydent.sydent import Sydent + logger = logging.getLogger(__name__) class LocalAssociationStore: - def __init__(self, sydent): + def __init__(self, sydent: "Sydent") -> None: self.sydent = sydent - def addOrUpdateAssociation(self, assoc): + def addOrUpdateAssociation(self, assoc: ThreepidAssociation) -> None: """ Updates an association, or creates one if none exists with these parameters. @@ -54,7 +58,9 @@ def addOrUpdateAssociation(self, assoc): ) self.sydent.db.commit() - def getAssociationsAfterId(self, afterId, limit=None): + def getAssociationsAfterId( + self, afterId: Optional[int], limit: Optional[int] = None + ) -> Tuple[Dict[int, ThreepidAssociation], Optional[int]]: """ Retrieves every association after the given ID. @@ -97,7 +103,9 @@ def getAssociationsAfterId(self, afterId, limit=None): return assocs, maxId - def getSignedAssociationsAfterId(self, afterId, limit=None): + def getSignedAssociationsAfterId( + self, afterId: Optional[int], limit: Optional[int] = None + ) -> Tuple[Dict[int, Dict[str, Any]], Optional[int]]: """Get associations after a given ID, and sign them before returning :param afterId: The ID to return results after (not inclusive) @@ -124,7 +132,7 @@ def getSignedAssociationsAfterId(self, afterId, limit=None): return assocs, maxId - def removeAssociation(self, threepid, mxid): + def removeAssociation(self, threepid: Dict[str, str], mxid: str) -> None: """ Delete the association between a 3PID and a MXID, if it exists. If the association doesn't exist, log and do nothing. @@ -176,10 +184,12 @@ def removeAssociation(self, threepid, mxid): class GlobalAssociationStore: - def __init__(self, sydent): + def __init__(self, sydent: "Sydent") -> None: self.sydent = sydent - def signedAssociationStringForThreepid(self, medium, address): + def signedAssociationStringForThreepid( + self, medium: str, address: str + ) -> Optional[str]: """ Retrieve the JSON for the signed association matching the provided 3PID, if one exists. @@ -214,7 +224,7 @@ def signedAssociationStringForThreepid(self, medium, address): return sgAssocStr - def getMxid(self, medium, address): + def getMxid(self, medium: str, address: str) -> Optional[str]: """ Retrieves the MXID associated with a 3PID. @@ -241,7 +251,9 @@ def getMxid(self, medium, address): return row[0] - def getMxids(self, threepid_tuples): + def getMxids( + self, threepid_tuples: List[Tuple[str, str]] + ) -> List[Tuple[str, str, str]]: """Given a list of threepid_tuples, return the same list but with mxids appended to each tuple for which a match was found in the database for. Output is ordered by medium, address, timestamp DESC @@ -281,7 +293,7 @@ def getMxids(self, threepid_tuples): ) results = [] - current = () + current = None for row in res.fetchall(): # only use the most recent entry for each # threepid (they're sorted by ts) @@ -295,7 +307,14 @@ def getMxids(self, threepid_tuples): return results - def addAssociation(self, assoc, rawSgAssoc, originServer, originId, commit=True): + def addAssociation( + self, + assoc: ThreepidAssociation, + rawSgAssoc: Dict[str, Any], + originServer: str, + originId: int, + commit: bool = True, + ) -> None: """ Saves an association received through either a replication push or a local push. @@ -333,7 +352,7 @@ def addAssociation(self, assoc, rawSgAssoc, originServer, originId, commit=True) if commit: self.sydent.db.commit() - def lastIdFromServer(self, server): + def lastIdFromServer(self, server: str) -> Optional[int]: """ Retrieves the ID of the last association received from the given peer. @@ -357,7 +376,7 @@ def lastIdFromServer(self, server): return row[0] - def removeAssociation(self, medium, address): + def removeAssociation(self, medium: str, address: str) -> None: """ Removes any association stored for the provided 3PID. @@ -380,7 +399,7 @@ def removeAssociation(self, medium, address): ) self.sydent.db.commit() - def retrieveMxidsForHashes(self, addresses): + def retrieveMxidsForHashes(self, addresses: List[str]) -> Dict[str, str]: """Returns a mapping from hash: mxid from a list of given lookup_hash values :param addresses: An array of lookup_hash values to check against the db @@ -403,14 +422,14 @@ def retrieveMxidsForHashes(self, addresses): results = {} try: # Convert list of addresses to list of tuples of addresses - addresses = [(x,) for x in addresses] + tuplized_addresses = [(x,) for x in addresses] inserted_cap = 0 - while inserted_cap < len(addresses): + while inserted_cap < len(tuplized_addresses): cur.executemany( "INSERT INTO tmp_retrieve_mxids_for_hashes(lookup_hash) " "VALUES (?)", - addresses[inserted_cap : inserted_cap + 500], + tuplized_addresses[inserted_cap : inserted_cap + 500], ) inserted_cap += 500 diff --git a/sydent/db/valsession.py b/sydent/db/valsession.py index f52d777b..aba15322 100644 --- a/sydent/db/valsession.py +++ b/sydent/db/valsession.py @@ -16,6 +16,7 @@ from __future__ import absolute_import from random import SystemRandom +from typing import TYPE_CHECKING, Optional import sydent.util.tokenutils from sydent.util import time_msec @@ -27,13 +28,18 @@ ValidationSession, ) +if TYPE_CHECKING: + from sydent.sydent import Sydent + class ThreePidValSessionStore: - def __init__(self, syd): + def __init__(self, syd: "Sydent") -> None: self.sydent = syd self.random = SystemRandom() - def getOrCreateTokenSession(self, medium, address, clientSecret): + def getOrCreateTokenSession( + self, medium: str, address: str, clientSecret: str + ) -> ValidationSession: """ Retrieves the validation session for a given medium, address and client secret, or creates one if none was found. @@ -82,7 +88,14 @@ def getOrCreateTokenSession(self, medium, address, clientSecret): ) return s - def addValSession(self, medium, address, clientSecret, mtime, commit=True): + def addValSession( + self, + medium: str, + address: str, + clientSecret: str, + mtime: int, + commit: bool = True, + ) -> int: """ Creates a validation session with the given parameters. @@ -117,7 +130,7 @@ def addValSession(self, medium, address, clientSecret, mtime, commit=True): self.sydent.db.commit() return sid - def setSendAttemptNumber(self, sid, attemptNo): + def setSendAttemptNumber(self, sid: int, attemptNo: int) -> None: """ Updates the send attempt number for the session with the given ID. @@ -134,7 +147,7 @@ def setSendAttemptNumber(self, sid, attemptNo): ) self.sydent.db.commit() - def setValidated(self, sid, validated): + def setValidated(self, sid: int, validated: bool) -> None: """ Updates a session to set the validated flag to the given value. @@ -151,7 +164,7 @@ def setValidated(self, sid, validated): ) self.sydent.db.commit() - def setMtime(self, sid, mtime): + def setMtime(self, sid: int, mtime: int) -> None: """ Set the time of the last send attempt for the session with the given ID @@ -168,7 +181,7 @@ def setMtime(self, sid, mtime): ) self.sydent.db.commit() - def getSessionById(self, sid): + def getSessionById(self, sid: int) -> Optional[ValidationSession]: """ Retrieves the session matching the given sid. @@ -195,7 +208,7 @@ def getSessionById(self, sid): row[0], row[1], row[2], row[3], row[4], row[5], None, None ) - def getTokenSessionById(self, sid): + def getTokenSessionById(self, sid: int) -> Optional[ValidationSession]: """ Retrieves a validation session using the session's ID. @@ -223,7 +236,7 @@ def getTokenSessionById(self, sid): return None - def getValidatedSession(self, sid, clientSecret): + def getValidatedSession(self, sid: int, clientSecret: str) -> ValidationSession: """ Retrieve a validated and still-valid session whose client secret matches the one passed in. @@ -260,7 +273,7 @@ def getValidatedSession(self, sid, clientSecret): return s - def deleteOldSessions(self): + def deleteOldSessions(self) -> None: """Delete old threepid validation sessions that are long expired.""" cur = self.sydent.db.cursor() diff --git a/sydent/hs_federation/verifier.py b/sydent/hs_federation/verifier.py index 31481a66..2dfb0a28 100644 --- a/sydent/hs_federation/verifier.py +++ b/sydent/hs_federation/verifier.py @@ -17,16 +17,21 @@ import logging import time +from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple -import signedjson.key -import signedjson.sign +import signedjson.key # type: ignore +import signedjson.sign # type: ignore from signedjson.sign import SignatureVerifyException from twisted.internet import defer -from unpaddedbase64 import decode_base64 +from twisted.web.server import Request +from unpaddedbase64 import decode_base64 # type: ignore from sydent.http.httpclient import FederationHttpClient from sydent.util.stringutils import is_valid_matrix_server_name +if TYPE_CHECKING: + from sydent.sydent import Sydent + logger = logging.getLogger(__name__) @@ -53,16 +58,16 @@ class Verifier(object): verifying that the signature on the json blob matches. """ - def __init__(self, sydent): + def __init__(self, sydent: "Sydent") -> None: self.sydent = sydent # Cache of server keys. These are cached until the 'valid_until_ts' time # in the result. - self.cache = { + self.cache: Dict[str, Any] = { # server_name: , } @defer.inlineCallbacks - def _getKeysForServer(self, server_name): + def _getKeysForServer(self, server_name: str) -> Generator: """Get the signing key data from a homeserver. :param server_name: The name of the server to request the keys from. @@ -105,7 +110,11 @@ def _getKeysForServer(self, server_name): defer.returnValue(result["verify_keys"]) @defer.inlineCallbacks - def verifyServerSignedJson(self, signed_json, acceptable_server_names=None): + def verifyServerSignedJson( + self, + signed_json: Dict[str, Any], + acceptable_server_names: Optional[List[str]] = None, + ) -> Generator: """Given a signed json object, try to verify any one of the signatures on it @@ -162,7 +171,9 @@ def verifyServerSignedJson(self, signed_json, acceptable_server_names=None): raise SignatureVerifyException("No matching signature found") @defer.inlineCallbacks - def authenticate_request(self, request, content): + def authenticate_request( + self, request: "Request", content: Optional[bytes] + ) -> Generator: """Authenticates a Matrix federation request based on the X-Matrix header XXX: Copied largely from synapse @@ -186,7 +197,7 @@ def authenticate_request(self, request, content): origin = None - def parse_auth_header(header_str): + def parse_auth_header(header_str: str) -> Tuple[str, str, str]: """ Extracts a server name, signing key and payload signature from an authentication header. diff --git a/sydent/http/auth.py b/sydent/http/auth.py index fde36bb3..d4e78066 100644 --- a/sydent/http/auth.py +++ b/sydent/http/auth.py @@ -16,15 +16,22 @@ from __future__ import absolute_import import logging +from typing import TYPE_CHECKING, Optional + +from twisted.web.server import Request from sydent.db.accounts import AccountStore from sydent.http.servlets import MatrixRestError, get_args from sydent.terms.terms import get_terms +if TYPE_CHECKING: + from sydent.db.accounts import Account + from sydent.sydent import Sydent + logger = logging.getLogger(__name__) -def tokenFromRequest(request): +def tokenFromRequest(request: "Request") -> Optional[str]: """Extract token from header of query parameter. :param request: The request to look for an access token in. @@ -51,7 +58,11 @@ def tokenFromRequest(request): return token -def authV2(sydent, request, requireTermsAgreed=True): +def authV2( + sydent: "Sydent", + request: "Request", + requireTermsAgreed: bool = True, +) -> "Account": """For v2 APIs check that the request has a valid access token associated with it :param sydent: The Sydent instance to use. @@ -61,8 +72,7 @@ def authV2(sydent, request, requireTermsAgreed=True): :param requireTermsAgreed: Whether to deny authentication if the user hasn't accepted the terms of service. - :returns Account|None: The account object if there is correct auth, or None for v1 - APIs. + :returns Account: The account object if there is correct auth :raises MatrixRestError: If the request is v2 but could not be authed or the user has not accepted terms. """ diff --git a/sydent/http/httpclient.py b/sydent/http/httpclient.py index 54f4f1b7..bfdcf2bc 100644 --- a/sydent/http/httpclient.py +++ b/sydent/http/httpclient.py @@ -18,10 +18,12 @@ import json import logging from io import BytesIO +from typing import TYPE_CHECKING, Any, Dict, Generator, Optional from twisted.internet import defer from twisted.web.client import Agent, FileBodyProducer from twisted.web.http_headers import Headers +from twisted.web.iweb import IAgent from sydent.http.blacklisting_reactor import BlacklistingReactorWrapper from sydent.http.federation_tls_options import ClientTLSOptionsFactory @@ -29,6 +31,9 @@ from sydent.http.matrixfederationagent import MatrixFederationAgent from sydent.util import json_decoder +if TYPE_CHECKING: + from sydent.sydent import Sydent + logger = logging.getLogger(__name__) @@ -37,8 +42,10 @@ class HTTPClient(object): requests. """ + agent: IAgent + @defer.inlineCallbacks - def get_json(self, uri, max_size=None): + def get_json(self, uri: str, max_size: Optional[int] = None) -> Generator: """Make a GET request to an endpoint returning JSON and parse result :param uri: The URI to make a GET request to. @@ -66,7 +73,9 @@ def get_json(self, uri, max_size=None): defer.returnValue(json_body) @defer.inlineCallbacks - def post_json_get_nothing(self, uri, post_json, opts): + def post_json_get_nothing( + self, uri: str, post_json: Dict[Any, Any], opts: Dict[str, Any] + ) -> Generator: """Make a POST request to an endpoint returning JSON and parse result :param uri: The URI to make a POST request to. @@ -120,7 +129,7 @@ class SimpleHttpClient(HTTPClient): from Synapse. """ - def __init__(self, sydent): + def __init__(self, sydent: "Sydent") -> None: self.sydent = sydent # The default endpoint factory in Twisted 14.0.0 (which we require) uses the # BrowserLikePolicyForHTTPS context factory which will do regular cert validation @@ -140,7 +149,7 @@ class FederationHttpClient(HTTPClient): MatrixFederationAgent. """ - def __init__(self, sydent): + def __init__(self, sydent: "Sydent") -> None: self.sydent = sydent self.agent = MatrixFederationAgent( BlacklistingReactorWrapper( diff --git a/sydent/http/httpcommon.py b/sydent/http/httpcommon.py index 28fbd844..6d595d4b 100644 --- a/sydent/http/httpcommon.py +++ b/sydent/http/httpcommon.py @@ -16,6 +16,7 @@ import logging from io import BytesIO +from typing import TYPE_CHECKING import twisted.internet.ssl from twisted.internet import defer, protocol @@ -25,6 +26,10 @@ from twisted.web.http import PotentialDataLoss from twisted.web.iweb import UNKNOWN_LENGTH +if TYPE_CHECKING: + from sydent.sydent import Sydent + + logger = logging.getLogger(__name__) # Arbitrarily limited to 512 KiB. @@ -32,7 +37,7 @@ class SslComponents: - def __init__(self, sydent): + def __init__(self, sydent: "Sydent") -> None: self.sydent = sydent self.myPrivateCertificate = self.makeMyCertificate() diff --git a/sydent/http/httpsclient.py b/sydent/http/httpsclient.py index 442b208c..942a34d3 100644 --- a/sydent/http/httpsclient.py +++ b/sydent/http/httpsclient.py @@ -18,13 +18,18 @@ import json import logging from io import BytesIO +from typing import TYPE_CHECKING, Any, Dict, Optional +from twisted.internet.defer import Deferred from twisted.internet.ssl import optionsForClientTLS from twisted.web.client import Agent, FileBodyProducer from twisted.web.http_headers import Headers from twisted.web.iweb import IPolicyForHTTPS from zope.interface import implementer +if TYPE_CHECKING: + from sydent.sydent import Sydent + logger = logging.getLogger(__name__) @@ -35,7 +40,7 @@ class ReplicationHttpsClient: replication HTTPS server) """ - def __init__(self, sydent): + def __init__(self, sydent: "Sydent") -> None: self.sydent = sydent self.agent = None @@ -47,7 +52,7 @@ def __init__(self, sydent): # trustRoot=self.sydent.sslComponents.trustRoot) self.agent = Agent(self.sydent.reactor, SydentPolicyForHTTPS(self.sydent)) - def postJson(self, uri, jsonObject): + def postJson(self, uri: str, jsonObject: Dict[Any, Any]) -> Optional[Deferred]: """ Sends an POST request over HTTPS. @@ -62,7 +67,7 @@ def postJson(self, uri, jsonObject): logger.debug("POSTing request to %s", uri) if not self.agent: logger.error("HTTPS post attempted but HTTPS is not configured") - return + return None headers = Headers( {"Content-Type": ["application/json"], "User-Agent": ["Sydent"]} @@ -78,7 +83,7 @@ def postJson(self, uri, jsonObject): @implementer(IPolicyForHTTPS) class SydentPolicyForHTTPS(object): - def __init__(self, sydent): + def __init__(self, sydent: "Sydent") -> None: self.sydent = sydent def creatorForNetloc(self, hostname, port): diff --git a/sydent/http/httpserver.py b/sydent/http/httpserver.py index 6ee7d72b..1b4e3bb0 100644 --- a/sydent/http/httpserver.py +++ b/sydent/http/httpserver.py @@ -18,6 +18,7 @@ from __future__ import absolute_import import logging +from typing import TYPE_CHECKING import twisted.internet.ssl from twisted.web.resource import Resource @@ -31,11 +32,14 @@ AuthenticatedUnbindThreePidServlet, ) +if TYPE_CHECKING: + from sydent.sydent import Sydent + logger = logging.getLogger(__name__) class ClientApiHttpServer: - def __init__(self, sydent): + def __init__(self, sydent: "Sydent") -> None: self.sydent = sydent root = Resource() @@ -149,7 +153,7 @@ def setup(self): class InternalApiHttpServer(object): - def __init__(self, sydent): + def __init__(self, sydent: "Sydent") -> None: self.sydent = sydent def setup(self, interface, port): @@ -177,7 +181,7 @@ def setup(self, interface, port): class ReplicationHttpsServer: - def __init__(self, sydent): + def __init__(self, sydent: "Sydent") -> None: self.sydent = sydent root = Resource() diff --git a/sydent/http/matrixfederationagent.py b/sydent/http/matrixfederationagent.py index 5662fb67..45540012 100644 --- a/sydent/http/matrixfederationagent.py +++ b/sydent/http/matrixfederationagent.py @@ -17,16 +17,18 @@ import logging import random import time +from typing import Generator, Optional import attr -from netaddr import IPAddress +from netaddr import IPAddress # type: ignore from twisted.internet import defer +from twisted.internet.defer import Deferred from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS from twisted.internet.interfaces import IStreamClientEndpoint from twisted.web.client import URI, Agent, HTTPConnectionPool, RedirectAgent from twisted.web.http import stringToDatetime from twisted.web.http_headers import Headers -from twisted.web.iweb import IAgent +from twisted.web.iweb import IAgent, IBodyProducer from zope.interface import implementer from sydent.http.httpcommon import read_body_with_max_size @@ -84,9 +86,9 @@ def __init__( reactor, tls_client_options_factory, _well_known_tls_policy=None, - _srv_resolver=None, - _well_known_cache=well_known_cache, - ): + _srv_resolver: Optional["SrvResolver"] = None, + _well_known_cache: Optional["TTLCache"] = well_known_cache, + ) -> None: self._reactor = reactor self._tls_client_options_factory = tls_client_options_factory @@ -116,7 +118,13 @@ def __init__( self._well_known_cache = _well_known_cache @defer.inlineCallbacks - def request(self, method, uri, headers=None, bodyProducer=None): + def request( + self, + method: bytes, + uri: bytes, + headers: Optional["Headers"] = None, + bodyProducer: Optional["IBodyProducer"] = None, + ) -> Generator: """ :param method: HTTP method (GET/POST/etc). :type method: bytes @@ -181,7 +189,9 @@ def endpointForURI(_uri): defer.returnValue(res) @defer.inlineCallbacks - def _route_matrix_uri(self, parsed_uri, lookup_well_known=True): + def _route_matrix_uri( + self, parsed_uri: "URI", lookup_well_known: bool = True + ) -> "Deferred": """Helper for `request`: determine the routing for a Matrix URI :param parsed_uri: uri to route. Note that it should be parsed with @@ -294,7 +304,7 @@ def _route_matrix_uri(self, parsed_uri, lookup_well_known=True): ) @defer.inlineCallbacks - def _get_well_known(self, server_name): + def _get_well_known(self, server_name: bytes) -> Generator: """Attempt to fetch and parse a .well-known file for the given server :param server_name: Name of the server, from the requested url. @@ -317,7 +327,7 @@ def _get_well_known(self, server_name): defer.returnValue(result) @defer.inlineCallbacks - def _do_get_well_known(self, server_name): + def _do_get_well_known(self, server_name: bytes) -> Generator: """Actually fetch and parse a .well-known, without checking the cache :param server_name: Name of the server, from the requested url diff --git a/sydent/http/srvresolver.py b/sydent/http/srvresolver.py index 586bd85f..db4a72e8 100644 --- a/sydent/http/srvresolver.py +++ b/sydent/http/srvresolver.py @@ -17,10 +17,12 @@ import logging import random import time +from typing import Callable, Dict, Generator, List, SupportsInt, Tuple import attr from twisted.internet import defer from twisted.internet.error import ConnectError +from twisted.internet.interfaces import IResolver from twisted.names import client, dns from twisted.names.error import DNSNameError, DomainError @@ -49,7 +51,7 @@ class Server(object): expires = attr.ib(default=0) -def pick_server_from_list(server_list): +def pick_server_from_list(server_list: List[Server]) -> Tuple[bytes, int]: """Randomly choose a server from the server list. :param server_list: List of candidate servers. @@ -96,13 +98,18 @@ class SrvResolver(object): :type get_time: callable """ - def __init__(self, dns_client=client, cache=SERVER_CACHE, get_time=time.time): + def __init__( + self, + dns_client: "IResolver" = client, + cache: Dict[bytes, List[Server]] = SERVER_CACHE, + get_time: Callable[[], SupportsInt] = time.time, + ) -> None: self._dns_client = dns_client self._cache = cache self._get_time = get_time @defer.inlineCallbacks - def resolve_service(self, service_name): + def resolve_service(self, service_name: bytes) -> "Generator": """Look up a SRV record :param service_name: The record to look up. diff --git a/sydent/replication/peer.py b/sydent/replication/peer.py index c969d224..827afab8 100644 --- a/sydent/replication/peer.py +++ b/sydent/replication/peer.py @@ -19,13 +19,16 @@ import binascii import json import logging +from typing import TYPE_CHECKING, Any, Dict -import signedjson.key -import signedjson.sign +import signedjson.key # type: ignore +import signedjson.sign # type: ignore from six.moves import configparser from twisted.internet import defer +from twisted.internet.defer import Deferred from twisted.web.client import readBody -from unpaddedbase64 import decode_base64 +from twisted.web.iweb import IResponse +from unpaddedbase64 import decode_base64 # type: ignore from sydent.config import ConfigError from sydent.db.hashing_metadata import HashingMetadataStore @@ -34,6 +37,9 @@ from sydent.util import json_decoder from sydent.util.hash import sha256_and_url_safe_base64 +if TYPE_CHECKING: + from sydent.sydent import Sydent + logger = logging.getLogger(__name__) SIGNING_KEY_ALGORITHM = "ed25519" @@ -45,7 +51,7 @@ def __init__(self, servername, pubkeys): self.pubkeys = pubkeys self.is_being_pushed_to = False - def pushUpdates(self, sgAssocs): + def pushUpdates(self, sgAssocs) -> "Deferred": """ :param sgAssocs: Sequence of (originId, sgAssoc) tuples where originId is the id on the creating server and sgAssoc is the json object of the signed association @@ -59,7 +65,7 @@ class LocalPeer(Peer): The local peer (ourselves: essentially copying from the local associations table to the global one) """ - def __init__(self, sydent): + def __init__(self, sydent: "Sydent") -> None: super(LocalPeer, self).__init__(sydent.server_name, {}) self.sydent = sydent self.hashing_store = HashingMetadataStore(sydent) @@ -69,7 +75,7 @@ def __init__(self, sydent): if self.lastId is None: self.lastId = -1 - def pushUpdates(self, sgAssocs): + def pushUpdates(self, sgAssocs: Dict[int, Dict[str, Any]]) -> "Deferred": """ Saves the given associations in the global associations store. Only stores an association if its ID is greater than the last seen ID. @@ -114,7 +120,14 @@ def pushUpdates(self, sgAssocs): class RemotePeer(Peer): - def __init__(self, sydent, server_name, port, pubkeys, lastSentVersion): + def __init__( + self, + sydent: "Sydent", + server_name: str, + port: int, + pubkeys: Dict[str, str], + lastSentVersion: int, + ) -> None: """ :param sydent: The current Sydent instance. :type sydent: sydent.sydent.Sydent @@ -181,7 +194,7 @@ def __init__(self, sydent, server_name, port, pubkeys, lastSentVersion): self.verify_key.alg = SIGNING_KEY_ALGORITHM self.verify_key.version = 0 - def verifySignedAssociation(self, assoc): + def verifySignedAssociation(self, assoc: Dict[Any, Any]) -> None: """Verifies a signature on a signed association. Raises an exception if the signature is incorrect or couldn't be verified. @@ -205,7 +218,7 @@ def verifySignedAssociation(self, assoc): # Verify the JSON signedjson.sign.verify_signed_json(assoc, self.servername, self.verify_key) - def pushUpdates(self, sgAssocs): + def pushUpdates(self, sgAssocs: Dict[int, Dict[str, Any]]) -> "Deferred": """ Pushes the given associations to the peer. @@ -233,7 +246,11 @@ def pushUpdates(self, sgAssocs): return updateDeferred - def _pushSuccess(self, result, updateDeferred): + def _pushSuccess( + self, + result: "IResponse", + updateDeferred: "Deferred", + ) -> None: """ Processes a successful push request. If the request resulted in a status code that's not a success, consider it a failure @@ -251,7 +268,7 @@ def _pushSuccess(self, result, updateDeferred): d.addCallback(self._failedPushBodyRead, updateDeferred=updateDeferred) d.addErrback(self._pushFailed, updateDeferred=updateDeferred) - def _failedPushBodyRead(self, body, updateDeferred): + def _failedPushBodyRead(self, body: bytes, updateDeferred: "Deferred") -> None: """ Processes a response body from a failed push request, then calls the error callback of the provided deferred. @@ -266,7 +283,11 @@ def _failedPushBodyRead(self, body, updateDeferred): e.errorDict = errObj updateDeferred.errback(e) - def _pushFailed(self, failure, updateDeferred): + def _pushFailed( + self, + failure, + updateDeferred: "Deferred", + ) -> None: """ Processes a failed push request, by calling the error callback of the given deferred with it. diff --git a/sydent/replication/pusher.py b/sydent/replication/pusher.py index adec54b8..af8ab0b4 100644 --- a/sydent/replication/pusher.py +++ b/sydent/replication/pusher.py @@ -17,6 +17,7 @@ from __future__ import absolute_import import logging +from typing import TYPE_CHECKING, Generator import twisted.internet.reactor import twisted.internet.task @@ -24,9 +25,12 @@ from sydent.db.peers import PeerStore from sydent.db.threepid_associations import LocalAssociationStore -from sydent.replication.peer import LocalPeer +from sydent.replication.peer import LocalPeer, RemotePeer from sydent.util import time_msec +if TYPE_CHECKING: + from sydent.sydent import Sydent + logger = logging.getLogger(__name__) # Maximum amount of signed associations to replicate to a peer at a time @@ -34,7 +38,7 @@ class Pusher: - def __init__(self, sydent): + def __init__(self, sydent: "Sydent") -> None: self.sydent = sydent self.pushing = False self.peerStore = PeerStore(self.sydent) @@ -45,7 +49,7 @@ def setup(self): cb.clock = self.sydent.reactor cb.start(10.0) - def doLocalPush(self): + def doLocalPush(self) -> None: """ Synchronously push local associations to this server (ie. copy them to globals table) The local server is essentially treated the same as any other peer except we don't do @@ -74,7 +78,7 @@ def scheduledPush(self): return defer.DeferredList([self._push_to_peer(p) for p in peers]) @defer.inlineCallbacks - def _push_to_peer(self, p): + def _push_to_peer(self, p: "RemotePeer") -> Generator: """ For a given peer, retrieves the list of associations that were created since the last successful push to this peer (limited to ASSOCIATIONS_PUSH_LIMIT) and diff --git a/sydent/sms/openmarket.py b/sydent/sms/openmarket.py index 8af372f1..c7b11b88 100644 --- a/sydent/sms/openmarket.py +++ b/sydent/sms/openmarket.py @@ -17,12 +17,16 @@ import logging from base64 import b64encode +from typing import TYPE_CHECKING, Dict, Optional from twisted.internet import defer from twisted.web.http_headers import Headers from sydent.http.httpclient import SimpleHttpClient +if TYPE_CHECKING: + from sydent.sydent import Sydent + logger = logging.getLogger(__name__) @@ -40,7 +44,7 @@ } -def tonFromType(t): +def tonFromType(t: str) -> int: """ Get the type of number from the originator's type. @@ -56,12 +60,14 @@ def tonFromType(t): class OpenMarketSMS: - def __init__(self, sydent): + def __init__(self, sydent: "Sydent") -> None: self.sydent = sydent self.http_cli = SimpleHttpClient(sydent) @defer.inlineCallbacks - def sendTextSMS(self, body, dest, source=None): + def sendTextSMS( + self, body: Dict, dest: str, source: Optional[Dict[str, str]] = None + ) -> None: """ Sends a text message with the given body to the given MSISDN. @@ -91,7 +97,7 @@ def sendTextSMS(self, body, dest, source=None): password = self.sydent.cfg.get("sms", "password").encode("UTF-8") b64creds = b64encode(b"%s:%s" % (username, password)) - headers = Headers( + req_headers = Headers( { b"Authorization": [b"Basic " + b64creds], b"Content-Type": [b"application/json"], @@ -99,7 +105,7 @@ def sendTextSMS(self, body, dest, source=None): ) resp = yield self.http_cli.post_json_get_nothing( - API_BASE_URL, body, {"headers": headers} + API_BASE_URL, body, {"headers": req_headers} ) headers = dict(resp.headers.getAllRawHeaders()) diff --git a/sydent/terms/terms.py b/sydent/terms/terms.py index e7361aad..0568a2b8 100644 --- a/sydent/terms/terms.py +++ b/sydent/terms/terms.py @@ -15,6 +15,7 @@ # limitations under the License. import logging +from typing import Any, Dict, List, Optional, Set import yaml @@ -22,14 +23,14 @@ class Terms(object): - def __init__(self, yamlObj): + def __init__(self, yamlObj: Optional[Dict[str, Any]]) -> None: """ :param yamlObj: The parsed YAML. :type yamlObj: dict[str, any] or None """ self._rawTerms = yamlObj - def getMasterVersion(self): + def getMasterVersion(self) -> Optional[str]: """ :return: The global (master) version of the terms, or None if there are no terms of service for this server. @@ -43,7 +44,7 @@ def getMasterVersion(self): return version - def getForClient(self): + def getForClient(self) -> Dict[str, dict]: """ :return: A dict which value for the "policies" key is a dict which contains the "docs" part of the terms' YAML. That nested dict is empty if no terms. @@ -58,7 +59,7 @@ def getForClient(self): policies[docName].update(doc["langs"]) return {"policies": policies} - def getUrlSet(self): + def getUrlSet(self) -> Set[str]: """ :return: All the URLs for the terms in a set. Empty set if no terms. :rtype: set[unicode] @@ -76,7 +77,7 @@ def getUrlSet(self): urls.add(url) return urls - def urlListIsSufficient(self, urls): + def urlListIsSufficient(self, urls: List[str]) -> bool: """ Checks whether the provided list of URLs (which represents the list of terms accepted by the user) is enough to allow the creation of the user's account. @@ -102,7 +103,7 @@ def urlListIsSufficient(self, urls): return agreed == required -def get_terms(sydent): +def get_terms(sydent) -> Optional[Terms]: """Read and parse terms as specified in the config. :returns Terms @@ -139,3 +140,4 @@ def get_terms(sydent): logger.exception( "Couldn't read terms file '%s'", sydent.cfg.get("general", "terms.path") ) + return None diff --git a/sydent/threepid/bind.py b/sydent/threepid/bind.py index 641b1c3b..51dee029 100644 --- a/sydent/threepid/bind.py +++ b/sydent/threepid/bind.py @@ -19,8 +19,9 @@ import collections import logging import math +from typing import TYPE_CHECKING, Any, Dict, Union -import signedjson.sign +import signedjson.sign # type: ignore from twisted.internet import defer from sydent.db.hashing_metadata import HashingMetadataStore @@ -33,6 +34,9 @@ from sydent.util.hash import sha256_and_url_safe_base64 from sydent.util.stringutils import is_valid_matrix_server_name +if TYPE_CHECKING: + from sydent.sydent import Sydent + logger = logging.getLogger(__name__) @@ -40,11 +44,11 @@ class ThreepidBinder: # the lifetime of a 3pid association THREEPID_ASSOCIATION_LIFETIME_MS = 100 * 365 * 24 * 60 * 60 * 1000 - def __init__(self, sydent): + def __init__(self, sydent: "Sydent") -> None: self.sydent = sydent self.hashing_store = HashingMetadataStore(sydent) - def addBinding(self, medium, address, mxid): + def addBinding(self, medium: str, address: str, mxid: str) -> Dict[str, Any]: """ Binds the given 3pid to the given mxid. @@ -112,7 +116,7 @@ def addBinding(self, medium, address, mxid): return sgassoc - def removeBinding(self, threepid, mxid): + def removeBinding(self, threepid: Dict[str, str], mxid: str) -> None: """ Removes the binding between a given 3PID and a given MXID. @@ -126,7 +130,7 @@ def removeBinding(self, threepid, mxid): self.sydent.pusher.doLocalPush() @defer.inlineCallbacks - def _notify(self, assoc, attempt): + def _notify(self, assoc: Dict[str, Any], attempt: int) -> None: """ Sends data about a new association (and, if necessary, the associated invites) to the associated MXID's homeserver. @@ -193,7 +197,9 @@ def _notify(self, assoc, attempt): assoc["address"], ) - def _notifyErrback(self, assoc, attempt, error): + def _notifyErrback( + self, assoc: Dict[str, Any], attempt: int, error: Union[Exception, str] + ) -> None: """ Handles errors when trying to send an association down to a homeserver by logging the error and scheduling a new attempt. diff --git a/sydent/threepid/signer.py b/sydent/threepid/signer.py index 848217c9..1f03cfe1 100644 --- a/sydent/threepid/signer.py +++ b/sydent/threepid/signer.py @@ -14,14 +14,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -import signedjson.sign +from typing import TYPE_CHECKING, Any, Dict + +import signedjson.sign # type: ignore + +if TYPE_CHECKING: + from sydent.sydent import Sydent + from sydent.threepid import ThreepidAssociation class Signer: - def __init__(self, sydent): + def __init__(self, sydent: "Sydent") -> None: self.sydent = sydent - def signedThreePidAssociation(self, assoc): + def signedThreePidAssociation(self, assoc: "ThreepidAssociation") -> Dict[str, Any]: """ Signs a 3PID association. diff --git a/sydent/users/accounts.py b/sydent/users/accounts.py index 4bef44b0..c022230a 100644 --- a/sydent/users/accounts.py +++ b/sydent/users/accounts.py @@ -15,7 +15,7 @@ class Account(object): - def __init__(self, user_id, creation_ts, consent_version): + def __init__(self, user_id: str, creation_ts: int, consent_version: str) -> None: """ :param user_id: The Matrix user ID for the account. :type user_id: str diff --git a/sydent/users/tokens.py b/sydent/users/tokens.py index 30c837a4..5e9cc53f 100644 --- a/sydent/users/tokens.py +++ b/sydent/users/tokens.py @@ -16,14 +16,19 @@ import logging import time +from typing import TYPE_CHECKING from sydent.db.accounts import AccountStore from sydent.util.tokenutils import generateAlphanumericTokenOfLength +if TYPE_CHECKING: + from sydent.sydent import Sydent + + logger = logging.getLogger(__name__) -def issueToken(sydent, user_id): +def issueToken(sydent: "Sydent", user_id: str) -> str: """ Creates an account for the given Matrix user ID, then generates, saves and returns an access token for that account. diff --git a/sydent/util/emailutils.py b/sydent/util/emailutils.py index 9d4c789e..844c5772 100644 --- a/sydent/util/emailutils.py +++ b/sydent/util/emailutils.py @@ -31,13 +31,20 @@ else: from html import escape +from typing import TYPE_CHECKING, Any, Dict + from sydent.util import time_msec from sydent.util.tokenutils import generateAlphanumericTokenOfLength +if TYPE_CHECKING: + from sydent.sydent import Sydent + logger = logging.getLogger(__name__) -def sendEmail(sydent, templateFile, mailTo, substitutions): +def sendEmail( + sydent: "Sydent", templateFile: str, mailTo: str, substitutions: Dict[str, str] +) -> None: """ Sends an email with the given parameters. @@ -131,4 +138,5 @@ class EmailAddressException(Exception): class EmailSendException(Exception): + cause: Any # type hint added to prevent ""EmailSendException" has no attribute "cause"" error in Mypy pass diff --git a/sydent/util/hash.py b/sydent/util/hash.py index 799094e8..70df3a4e 100644 --- a/sydent/util/hash.py +++ b/sydent/util/hash.py @@ -16,10 +16,10 @@ import hashlib -import unpaddedbase64 +import unpaddedbase64 # type: ignore -def sha256_and_url_safe_base64(input_text): +def sha256_and_url_safe_base64(input_text: str) -> str: """SHA256 hash an input string, encode the digest as url-safe base64, and return diff --git a/sydent/util/ip_range.py b/sydent/util/ip_range.py index 9e81a63e..f747b2c6 100644 --- a/sydent/util/ip_range.py +++ b/sydent/util/ip_range.py @@ -15,7 +15,7 @@ import itertools from typing import Iterable, Optional -from netaddr import AddrFormatError, IPNetwork, IPSet +from netaddr import AddrFormatError, IPNetwork, IPSet # type: ignore # IP ranges that are considered private / unroutable / don't make sense. DEFAULT_IP_RANGE_BLACKLIST = [ diff --git a/sydent/util/stringutils.py b/sydent/util/stringutils.py index 3b9f66fe..f526a698 100644 --- a/sydent/util/stringutils.py +++ b/sydent/util/stringutils.py @@ -37,7 +37,7 @@ MAX_EMAIL_ADDRESS_LENGTH = 500 -def is_valid_client_secret(client_secret): +def is_valid_client_secret(client_secret: str) -> bool: """Validate that a given string matches the client_secret regex defined by the spec :param client_secret: The client_secret to validate diff --git a/sydent/util/tokenutils.py b/sydent/util/tokenutils.py index 8800f708..4e8ec3f4 100644 --- a/sydent/util/tokenutils.py +++ b/sydent/util/tokenutils.py @@ -20,7 +20,7 @@ r = random.SystemRandom() -def generateTokenForMedium(medium): +def generateTokenForMedium(medium: str) -> str: """ Generates a token of a different format depending on the medium, a 32 characters alphanumeric one if the medium is email, a 6 characters numeric one otherwise. @@ -37,7 +37,7 @@ def generateTokenForMedium(medium): return generateNumericTokenOfLength(6) -def generateNumericTokenOfLength(length): +def generateNumericTokenOfLength(length: int) -> str: """ Generates a token of the given length with the character set [0-9]. @@ -50,7 +50,7 @@ def generateNumericTokenOfLength(length): return u"".join([r.choice(string.digits) for _ in range(length)]) -def generateAlphanumericTokenOfLength(length): +def generateAlphanumericTokenOfLength(length: int) -> str: """ Generates a token of the given length with the character set [a-zA-Z0-9]. diff --git a/sydent/util/ttlcache.py b/sydent/util/ttlcache.py index 8d98c301..702ef618 100644 --- a/sydent/util/ttlcache.py +++ b/sydent/util/ttlcache.py @@ -15,9 +15,10 @@ import logging import time +from typing import Any, Tuple import attr -from sortedcontainers import SortedList +from sortedcontainers import SortedList # type: ignore logger = logging.getLogger(__name__) @@ -36,7 +37,7 @@ def __init__(self, cache_name, timer=time.time): self._timer = timer - def set(self, key, value, ttl): + def set(self, key, value, ttl: float) -> None: """Add/update an entry in the cache :param key: Key for this entry. @@ -74,7 +75,7 @@ def get(self, key, default=SENTINEL): return default return e.value - def get_with_expiry(self, key): + def get_with_expiry(self, key) -> Tuple[Any, float]: """Get a value, and its expiry time, from the cache :param key: key to look up diff --git a/sydent/validators/common.py b/sydent/validators/common.py index 787eccde..31efa217 100644 --- a/sydent/validators/common.py +++ b/sydent/validators/common.py @@ -1,6 +1,7 @@ from __future__ import absolute_import import logging +from typing import TYPE_CHECKING, Dict from sydent.db.valsession import ThreePidValSessionStore from sydent.util import time_msec @@ -12,10 +13,15 @@ ValidationSession, ) +if TYPE_CHECKING: + from sydent.sydent import Sydent + logger = logging.getLogger(__name__) -def validateSessionWithToken(sydent, sid, clientSecret, token): +def validateSessionWithToken( + sydent: "Sydent", sid: str, clientSecret: str, token: str +) -> Dict[str, bool]: """ Attempt to validate a session, identified by the sid, using the token from out-of-band. The client secret is given to diff --git a/sydent/validators/emailvalidator.py b/sydent/validators/emailvalidator.py index 959af527..d91b8e55 100644 --- a/sydent/validators/emailvalidator.py +++ b/sydent/validators/emailvalidator.py @@ -16,6 +16,7 @@ from __future__ import absolute_import import logging +from typing import TYPE_CHECKING, Dict, Optional from six.moves import urllib @@ -24,22 +25,26 @@ from sydent.util.emailutils import sendEmail from sydent.validators import common +if TYPE_CHECKING: + from sydent.sydent import Sydent + from sydent.validators import ValidationSession + logger = logging.getLogger(__name__) class EmailValidator: - def __init__(self, sydent): + def __init__(self, sydent: "Sydent") -> None: self.sydent = sydent def requestToken( self, - emailAddress, - clientSecret, - sendAttempt, - nextLink, - ipaddress=None, - brand=None, - ): + emailAddress: str, + clientSecret: str, + sendAttempt: int, + nextLink: str, + ipaddress: Optional[str] = None, + brand: Optional[str] = None, + ) -> int: """ Creates or retrieves a validation session and sends an email to the corresponding email address with a token to use to verify the association. @@ -102,7 +107,9 @@ def requestToken( return valSession.id - def makeValidateLink(self, valSession, clientSecret, nextLink): + def makeValidateLink( + self, valSession: "ValidationSession", clientSecret: str, nextLink: str + ) -> str: """ Creates a validation link that can be sent via email to the user. @@ -137,7 +144,9 @@ def makeValidateLink(self, valSession, clientSecret, nextLink): link += "&nextLink=%s" % (urllib.parse.quote(nextLink)) return link - def validateSessionWithToken(self, sid, clientSecret, token): + def validateSessionWithToken( + self, sid: str, clientSecret: str, token: str + ) -> Dict[str, bool]: """ Validates the session with the given ID. diff --git a/sydent/validators/msisdnvalidator.py b/sydent/validators/msisdnvalidator.py index ada2cbdb..c6ba57ac 100644 --- a/sydent/validators/msisdnvalidator.py +++ b/sydent/validators/msisdnvalidator.py @@ -17,24 +17,28 @@ from __future__ import absolute_import import logging +from typing import TYPE_CHECKING, Dict, List, Optional -import phonenumbers +import phonenumbers # type: ignore from sydent.db.valsession import ThreePidValSessionStore from sydent.sms.openmarket import OpenMarketSMS from sydent.util import time_msec from sydent.validators import DestinationRejectedException, common +if TYPE_CHECKING: + from sydent.sydent import Sydent + logger = logging.getLogger(__name__) class MsisdnValidator: - def __init__(self, sydent): + def __init__(self, sydent: "Sydent") -> None: self.sydent = sydent self.omSms = OpenMarketSMS(sydent) # cache originators & sms rules from config file - self.originators = {} + self.originators: Dict[str, List[Dict[str, str]]] = {} self.smsRules = {} for opt in self.sydent.cfg.options("sms"): if opt.startswith("originators."): @@ -71,7 +75,13 @@ def __init__(self, sydent): self.smsRules[country] = action - def requestToken(self, phoneNumber, clientSecret, sendAttempt, brand=None): + def requestToken( + self, + phoneNumber: phonenumbers.PhoneNumber, + clientSecret: str, + sendAttempt: int, + brand: Optional[str] = None, + ) -> int: """ Creates or retrieves a validation session and sends an text message to the corresponding phone number address with a token to use to verify the association. @@ -132,7 +142,9 @@ def requestToken(self, phoneNumber, clientSecret, sendAttempt, brand=None): return valSession.id - def getOriginator(self, destPhoneNumber): + def getOriginator( + self, destPhoneNumber: phonenumbers.PhoneNumber + ) -> Dict[str, str]: """ Gets an originator for a given phone number. @@ -165,7 +177,9 @@ def getOriginator(self, destPhoneNumber): )[1:] return origs[sum([int(i) for i in msisdn]) % len(origs)] - def validateSessionWithToken(self, sid, clientSecret, token): + def validateSessionWithToken( + self, sid: str, clientSecret: str, token: str + ) -> Dict[str, bool]: """ Validates the session with the given ID.