Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add type hints for checking by mypy #355

Merged
merged 14 commits into from
Jun 9, 2021
1 change: 1 addition & 0 deletions changelog.d/355.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added type hints to support mypy checks.
19 changes: 13 additions & 6 deletions sydent/db/accounts.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,19 @@
# limitations under the License.
from __future__ import absolute_import

from typing import TYPE_CHECKING, Optional

from sydent.users.accounts import Account

if TYPE_CHECKING:
from sydent.sydent import Sydent


class AccountStore(object):
def __init__(self, sydent):
def __init__(self, sydent: "Sydent") -> None:
self.sydent = sydent

def getAccountByToken(self, token):
def getAccountByToken(self, token: str) -> Optional[Account]:
"""
Select the account matching the given token, if any.

Expand All @@ -45,7 +50,9 @@ def getAccountByToken(self, token):

return Account(*row)

def storeAccount(self, user_id, creation_ts, consent_version):
def storeAccount(
self, user_id: str, creation_ts: int, consent_version: Optional[str]
) -> None:
"""
Stores an account for the given user ID.

Expand All @@ -65,7 +72,7 @@ def storeAccount(self, user_id, creation_ts, consent_version):
)
self.sydent.db.commit()

def setConsentVersion(self, user_id, consent_version):
def setConsentVersion(self, user_id: str, consent_version: Optional[str]) -> None:
"""
Saves that the given user has agreed to all of the terms in the document of the
given version.
Expand All @@ -82,7 +89,7 @@ def setConsentVersion(self, user_id, consent_version):
)
self.sydent.db.commit()

def addToken(self, user_id, token):
def addToken(self, user_id: str, token: str) -> None:
"""
Stores the authentication token for a given user.

Expand All @@ -98,7 +105,7 @@ def addToken(self, user_id, token):
)
self.sydent.db.commit()

def delToken(self, token):
def delToken(self, token: str) -> int:
"""
Deletes an authentication token from the database.

Expand Down
21 changes: 17 additions & 4 deletions sydent/db/hashing_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,18 @@

# Actions on the hashing_metadata table which is defined in the migration process in
# sqlitedb.py
from sqlite3 import Cursor
from typing import TYPE_CHECKING, Callable, Optional

if TYPE_CHECKING:
from sydent.sydent import Sydent


class HashingMetadataStore:
def __init__(self, sydent):
def __init__(self, sydent: "Sydent") -> None:
self.sydent = sydent

def get_lookup_pepper(self):
def get_lookup_pepper(self) -> Optional[str]:
"""Return the value of the current lookup pepper from the db

:return: A pepper if it exists in the database, or None if one does
Expand All @@ -44,7 +49,9 @@ def get_lookup_pepper(self):

return pepper

def store_lookup_pepper(self, hashing_function, pepper):
def store_lookup_pepper(
self, hashing_function: Callable[[str], str], pepper: str
) -> None:
"""Stores a new lookup pepper in the hashing_metadata db table and rehashes all 3PIDs

:param hashing_function: A function with single input and output strings
Expand Down Expand Up @@ -74,7 +81,13 @@ def store_lookup_pepper(self, hashing_function, pepper):
# Commit the queued db transactions so that adding a new pepper and hashing is atomic
self.sydent.db.commit()

def _rehash_threepids(self, cur, hashing_function, pepper, table):
def _rehash_threepids(
self,
cur: Cursor,
hashing_function: Callable[[str], str],
pepper: str,
table: str,
) -> None:
"""Rehash 3PIDs of a given table using a given hashing_function and pepper

A database cursor `cur` must be passed to this function. After this function completes,
Expand Down
22 changes: 14 additions & 8 deletions sydent/db/invite_tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from typing import TYPE_CHECKING, Dict, List, Optional

if TYPE_CHECKING:
from sydent.sydent import Sydent


class JoinTokenStore(object):
def __init__(self, sydent):
def __init__(self, sydent: "Sydent") -> None:
self.sydent = sydent

def storeToken(self, medium, address, roomId, sender, token):
def storeToken(
self, medium: str, address: str, roomId: str, sender: str, token: str
) -> None:
"""
Store a new invite token and its metadata.

Expand All @@ -45,7 +51,7 @@ def storeToken(self, medium, address, roomId, sender, token):
)
self.sydent.db.commit()

def getTokens(self, medium, address):
def getTokens(self, medium: str, address: str) -> List[Dict[str, str]]:
"""
Retrieves the pending invites tokens for this 3PID that haven't been delivered
yet.
Expand Down Expand Up @@ -100,7 +106,7 @@ def getTokens(self, medium, address):

return ret

def markTokensAsSent(self, medium, address):
def markTokensAsSent(self, medium: str, address: str) -> None:
"""
Updates the invite tokens associated with a given 3PID to mark them as
delivered to a homeserver so they're not delivered again in the future.
Expand All @@ -122,7 +128,7 @@ def markTokensAsSent(self, medium, address):
)
self.sydent.db.commit()

def storeEphemeralPublicKey(self, publicKey):
def storeEphemeralPublicKey(self, publicKey: str) -> None:
"""
Saves the provided ephemeral public key.

Expand All @@ -138,7 +144,7 @@ def storeEphemeralPublicKey(self, publicKey):
)
self.sydent.db.commit()

def validateEphemeralPublicKey(self, publicKey):
def validateEphemeralPublicKey(self, publicKey: str) -> bool:
"""
Checks if an ephemeral public key is valid, and, if it is, updates its
verification count.
Expand All @@ -159,7 +165,7 @@ def validateEphemeralPublicKey(self, publicKey):
self.sydent.db.commit()
return cur.rowcount > 0

def getSenderForToken(self, token):
def getSenderForToken(self, token: str) -> Optional[str]:
"""
Retrieves the MXID of the user that sent the invite the provided token is for.

Expand All @@ -177,7 +183,7 @@ def getSenderForToken(self, token):
return rows[0][0]
return None

def deleteTokens(self, medium, address):
def deleteTokens(self, medium: str, address: str) -> None:
"""
Deletes every token for a given 3PID.

Expand Down
17 changes: 11 additions & 6 deletions sydent/db/peers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,19 @@
# limitations under the License.
from __future__ import absolute_import

from typing import TYPE_CHECKING, Dict, List, Optional

from sydent.replication.peer import RemotePeer

if TYPE_CHECKING:
from sydent.sydent import Sydent


class PeerStore:
def __init__(self, sydent):
def __init__(self, sydent: "Sydent") -> None:
self.sydent = sydent

def getPeerByName(self, name):
def getPeerByName(self, name: str) -> Optional[RemotePeer]:
"""
Retrieves a remote peer using it's server name.

Expand Down Expand Up @@ -57,7 +62,7 @@ def getPeerByName(self, name):

return p

def getAllPeers(self):
def getAllPeers(self) -> List[RemotePeer]:
"""
Retrieve all of the remote peers from the database.

Expand All @@ -75,7 +80,7 @@ def getAllPeers(self):
peername = None
port = None
lastSentVer = None
pubkeys = {}
pubkeys: Dict[str, str] = {}

for row in res.fetchall():
if row[0] != peername:
Expand All @@ -95,8 +100,8 @@ def getAllPeers(self):
return peers

def setLastSentVersionAndPokeSucceeded(
self, peerName, lastSentVersion, lastPokeSucceeded
):
self, peerName: str, lastSentVersion: int, lastPokeSucceeded: int
) -> None:
"""
Sets the ID of the last association sent to a given peer and the time of the
last successful request sent to that peer.
Expand Down
6 changes: 5 additions & 1 deletion sydent/db/sqlitedb.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,16 @@
import logging
import os
import sqlite3
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from sydent.sydent import Sydent

logger = logging.getLogger(__name__)


class SqliteDatabase:
def __init__(self, syd):
def __init__(self, syd: "Sydent") -> None:
self.sydent = syd

dbFilePath = self.sydent.cfg.get("db", "db.file")
Expand Down
11 changes: 8 additions & 3 deletions sydent/db/terms.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, List

if TYPE_CHECKING:
from sydent.sydent import Sydent


class TermsStore(object):
def __init__(self, sydent):
def __init__(self, sydent: "Sydent") -> None:
self.sydent = sydent

def getAgreedUrls(self, user_id):
def getAgreedUrls(self, user_id: str) -> List[str]:
"""
Retrieves the URLs of the terms the given user has agreed to.

Expand All @@ -45,7 +50,7 @@ def getAgreedUrls(self, user_id):

return urls

def addAgreedUrls(self, user_id, urls):
def addAgreedUrls(self, user_id: str, urls: List[str]) -> None:
"""
Saves that the given user has accepted the terms at the given URLs.

Expand Down
Loading