Skip to content

Commit

Permalink
Add type hints for checking by mypy (#355)
Browse files Browse the repository at this point in the history
Signed-off-by H-Shay: <shaysquared@gmail.com>
  • Loading branch information
H-Shay authored Jun 9, 2021
1 parent af32e62 commit 733272a
Show file tree
Hide file tree
Showing 34 changed files with 380 additions and 158 deletions.
1 change: 1 addition & 0 deletions changelog.d/355.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added type hints to support mypy checks.
19 changes: 13 additions & 6 deletions sydent/db/accounts.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,19 @@
# limitations under the License.
from __future__ import absolute_import

from typing import TYPE_CHECKING, Optional

from sydent.users.accounts import Account

if TYPE_CHECKING:
from sydent.sydent import Sydent


class AccountStore(object):
def __init__(self, sydent):
def __init__(self, sydent: "Sydent") -> None:
self.sydent = sydent

def getAccountByToken(self, token):
def getAccountByToken(self, token: str) -> Optional[Account]:
"""
Select the account matching the given token, if any.
Expand All @@ -45,7 +50,9 @@ def getAccountByToken(self, token):

return Account(*row)

def storeAccount(self, user_id, creation_ts, consent_version):
def storeAccount(
self, user_id: str, creation_ts: int, consent_version: Optional[str]
) -> None:
"""
Stores an account for the given user ID.
Expand All @@ -65,7 +72,7 @@ def storeAccount(self, user_id, creation_ts, consent_version):
)
self.sydent.db.commit()

def setConsentVersion(self, user_id, consent_version):
def setConsentVersion(self, user_id: str, consent_version: Optional[str]) -> None:
"""
Saves that the given user has agreed to all of the terms in the document of the
given version.
Expand All @@ -82,7 +89,7 @@ def setConsentVersion(self, user_id, consent_version):
)
self.sydent.db.commit()

def addToken(self, user_id, token):
def addToken(self, user_id: str, token: str) -> None:
"""
Stores the authentication token for a given user.
Expand All @@ -98,7 +105,7 @@ def addToken(self, user_id, token):
)
self.sydent.db.commit()

def delToken(self, token):
def delToken(self, token: str) -> int:
"""
Deletes an authentication token from the database.
Expand Down
21 changes: 17 additions & 4 deletions sydent/db/hashing_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,18 @@

# Actions on the hashing_metadata table which is defined in the migration process in
# sqlitedb.py
from sqlite3 import Cursor
from typing import TYPE_CHECKING, Callable, Optional

if TYPE_CHECKING:
from sydent.sydent import Sydent


class HashingMetadataStore:
def __init__(self, sydent):
def __init__(self, sydent: "Sydent") -> None:
self.sydent = sydent

def get_lookup_pepper(self):
def get_lookup_pepper(self) -> Optional[str]:
"""Return the value of the current lookup pepper from the db
:return: A pepper if it exists in the database, or None if one does
Expand All @@ -44,7 +49,9 @@ def get_lookup_pepper(self):

return pepper

def store_lookup_pepper(self, hashing_function, pepper):
def store_lookup_pepper(
self, hashing_function: Callable[[str], str], pepper: str
) -> None:
"""Stores a new lookup pepper in the hashing_metadata db table and rehashes all 3PIDs
:param hashing_function: A function with single input and output strings
Expand Down Expand Up @@ -74,7 +81,13 @@ def store_lookup_pepper(self, hashing_function, pepper):
# Commit the queued db transactions so that adding a new pepper and hashing is atomic
self.sydent.db.commit()

def _rehash_threepids(self, cur, hashing_function, pepper, table):
def _rehash_threepids(
self,
cur: Cursor,
hashing_function: Callable[[str], str],
pepper: str,
table: str,
) -> None:
"""Rehash 3PIDs of a given table using a given hashing_function and pepper
A database cursor `cur` must be passed to this function. After this function completes,
Expand Down
22 changes: 14 additions & 8 deletions sydent/db/invite_tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from typing import TYPE_CHECKING, Dict, List, Optional

if TYPE_CHECKING:
from sydent.sydent import Sydent


class JoinTokenStore(object):
def __init__(self, sydent):
def __init__(self, sydent: "Sydent") -> None:
self.sydent = sydent

def storeToken(self, medium, address, roomId, sender, token):
def storeToken(
self, medium: str, address: str, roomId: str, sender: str, token: str
) -> None:
"""
Store a new invite token and its metadata.
Expand All @@ -45,7 +51,7 @@ def storeToken(self, medium, address, roomId, sender, token):
)
self.sydent.db.commit()

def getTokens(self, medium, address):
def getTokens(self, medium: str, address: str) -> List[Dict[str, str]]:
"""
Retrieves the pending invites tokens for this 3PID that haven't been delivered
yet.
Expand Down Expand Up @@ -100,7 +106,7 @@ def getTokens(self, medium, address):

return ret

def markTokensAsSent(self, medium, address):
def markTokensAsSent(self, medium: str, address: str) -> None:
"""
Updates the invite tokens associated with a given 3PID to mark them as
delivered to a homeserver so they're not delivered again in the future.
Expand All @@ -122,7 +128,7 @@ def markTokensAsSent(self, medium, address):
)
self.sydent.db.commit()

def storeEphemeralPublicKey(self, publicKey):
def storeEphemeralPublicKey(self, publicKey: str) -> None:
"""
Saves the provided ephemeral public key.
Expand All @@ -138,7 +144,7 @@ def storeEphemeralPublicKey(self, publicKey):
)
self.sydent.db.commit()

def validateEphemeralPublicKey(self, publicKey):
def validateEphemeralPublicKey(self, publicKey: str) -> bool:
"""
Checks if an ephemeral public key is valid, and, if it is, updates its
verification count.
Expand All @@ -159,7 +165,7 @@ def validateEphemeralPublicKey(self, publicKey):
self.sydent.db.commit()
return cur.rowcount > 0

def getSenderForToken(self, token):
def getSenderForToken(self, token: str) -> Optional[str]:
"""
Retrieves the MXID of the user that sent the invite the provided token is for.
Expand All @@ -177,7 +183,7 @@ def getSenderForToken(self, token):
return rows[0][0]
return None

def deleteTokens(self, medium, address):
def deleteTokens(self, medium: str, address: str) -> None:
"""
Deletes every token for a given 3PID.
Expand Down
17 changes: 11 additions & 6 deletions sydent/db/peers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,19 @@
# limitations under the License.
from __future__ import absolute_import

from typing import TYPE_CHECKING, Dict, List, Optional

from sydent.replication.peer import RemotePeer

if TYPE_CHECKING:
from sydent.sydent import Sydent


class PeerStore:
def __init__(self, sydent):
def __init__(self, sydent: "Sydent") -> None:
self.sydent = sydent

def getPeerByName(self, name):
def getPeerByName(self, name: str) -> Optional[RemotePeer]:
"""
Retrieves a remote peer using it's server name.
Expand Down Expand Up @@ -57,7 +62,7 @@ def getPeerByName(self, name):

return p

def getAllPeers(self):
def getAllPeers(self) -> List[RemotePeer]:
"""
Retrieve all of the remote peers from the database.
Expand All @@ -75,7 +80,7 @@ def getAllPeers(self):
peername = None
port = None
lastSentVer = None
pubkeys = {}
pubkeys: Dict[str, str] = {}

for row in res.fetchall():
if row[0] != peername:
Expand All @@ -95,8 +100,8 @@ def getAllPeers(self):
return peers

def setLastSentVersionAndPokeSucceeded(
self, peerName, lastSentVersion, lastPokeSucceeded
):
self, peerName: str, lastSentVersion: int, lastPokeSucceeded: int
) -> None:
"""
Sets the ID of the last association sent to a given peer and the time of the
last successful request sent to that peer.
Expand Down
6 changes: 5 additions & 1 deletion sydent/db/sqlitedb.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,16 @@
import logging
import os
import sqlite3
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from sydent.sydent import Sydent

logger = logging.getLogger(__name__)


class SqliteDatabase:
def __init__(self, syd):
def __init__(self, syd: "Sydent") -> None:
self.sydent = syd

dbFilePath = self.sydent.cfg.get("db", "db.file")
Expand Down
11 changes: 8 additions & 3 deletions sydent/db/terms.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, List

if TYPE_CHECKING:
from sydent.sydent import Sydent


class TermsStore(object):
def __init__(self, sydent):
def __init__(self, sydent: "Sydent") -> None:
self.sydent = sydent

def getAgreedUrls(self, user_id):
def getAgreedUrls(self, user_id: str) -> List[str]:
"""
Retrieves the URLs of the terms the given user has agreed to.
Expand All @@ -45,7 +50,7 @@ def getAgreedUrls(self, user_id):

return urls

def addAgreedUrls(self, user_id, urls):
def addAgreedUrls(self, user_id: str, urls: List[str]) -> None:
"""
Saves that the given user has accepted the terms at the given URLs.
Expand Down
Loading

0 comments on commit 733272a

Please sign in to comment.