Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Type hints for RegistrationStore #8615

Merged
merged 5 commits into from
Oct 22, 2020
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/8615.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Type hints for `RegistrationStore`.
1 change: 1 addition & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ files =
synapse/spam_checker_api,
synapse/state,
synapse/storage/databases/main/events.py,
synapse/storage/databases/main/registration.py,
synapse/storage/databases/main/stream.py,
synapse/storage/databases/main/ui_auth.py,
synapse/storage/database.py,
Expand Down
1 change: 0 additions & 1 deletion synapse/storage/databases/main/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,6 @@ def __init__(self, database: DatabasePool, db_conn, hs):
db_conn, "e2e_cross_signing_keys", "stream_id"
)

self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
Expand Down
154 changes: 82 additions & 72 deletions synapse/storage/databases/main/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,29 +16,33 @@
# limitations under the License.
import logging
import re
from typing import Any, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.databases.main.stats import StatsStore
from synapse.storage.types import Cursor
from synapse.storage.util.id_generators import IdGenerator
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import UserID
from synapse.util.caches.descriptors import cached

if TYPE_CHECKING:
from synapse.server import HomeServer

THIRTY_MINUTES_IN_MS = 30 * 60 * 1000

logger = logging.getLogger(__name__)


class RegistrationWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
class RegistrationWorkerStore(CacheInvalidationWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we add a type for db_conn to each of these __init__ methods while we're at it?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't elsewhere, but that is a terrible reason not to do it.

super().__init__(database, db_conn, hs)

self.config = hs.config
self.clock = hs.get_clock()

# Note: we don't check this sequence for consistency as we'd have to
# call `find_max_generated_user_id_localpart` each time, which is
Expand All @@ -55,7 +59,7 @@ def __init__(self, database: DatabasePool, db_conn, hs):

# Create a background job for culling expired 3PID validity tokens
if hs.config.run_background_tasks:
self.clock.looping_call(
self._clock.looping_call(
self.cull_expired_threepid_validation_tokens, THIRTY_MINUTES_IN_MS
)

Expand Down Expand Up @@ -92,7 +96,7 @@ async def is_trial_user(self, user_id: str) -> bool:
if not info:
return False

now = self.clock.time_msec()
now = self._clock.time_msec()
trial_duration_ms = self.config.mau_trial_days * 24 * 60 * 60 * 1000
is_trial = (now - info["creation_ts"] * 1000) < trial_duration_ms
return is_trial
Expand Down Expand Up @@ -257,7 +261,7 @@ def select_users_txn(txn, now_ms, renew_at):
return await self.db_pool.runInteraction(
"get_users_expiring_soon",
select_users_txn,
self.clock.time_msec(),
self._clock.time_msec(),
self.config.account_validity.renew_at,
)

Expand Down Expand Up @@ -328,13 +332,17 @@ def set_server_admin_txn(txn):
await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn)

def _query_for_auth(self, txn, token):
sql = (
"SELECT users.name, users.is_guest, users.shadow_banned, access_tokens.id as token_id,"
" access_tokens.device_id, access_tokens.valid_until_ms"
" FROM users"
" INNER JOIN access_tokens on users.name = access_tokens.user_id"
" WHERE token = ?"
)
sql = """
SELECT users.name,
users.is_guest,
users.shadow_banned,
access_tokens.id as token_id,
access_tokens.device_id,
access_tokens.valid_until_ms
FROM users
INNER JOIN access_tokens on users.name = access_tokens.user_id
WHERE token = ?
"""

txn.execute(sql, (token,))
rows = self.db_pool.cursor_to_dict(txn)
Expand Down Expand Up @@ -803,7 +811,7 @@ def cull_expired_threepid_validation_tokens_txn(txn, ts):
await self.db_pool.runInteraction(
"cull_expired_threepid_validation_tokens",
cull_expired_threepid_validation_tokens_txn,
self.clock.time_msec(),
self._clock.time_msec(),
)

@wrap_as_background_process("account_validity_set_expiration_dates")
Expand Down Expand Up @@ -890,10 +898,10 @@ async def del_user_pending_deactivation(self, user_id: str) -> None:


class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)

self.clock = hs.get_clock()
self._clock = hs.get_clock()
self.config = hs.config

self.db_pool.updates.register_background_index_update(
Expand Down Expand Up @@ -1016,13 +1024,56 @@ def _bg_user_threepids_grandfather_txn(txn):

return 1

async def set_user_deactivated_status(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why were these moved?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They're used elsewhere in that class

self, user_id: str, deactivated: bool
) -> None:
"""Set the `deactivated` property for the provided user to the provided value.

Args:
user_id: The ID of the user to set the status for.
deactivated: The value to set for `deactivated`.
"""

await self.db_pool.runInteraction(
"set_user_deactivated_status",
self.set_user_deactivated_status_txn,
user_id,
deactivated,
)

def set_user_deactivated_status_txn(self, txn, user_id: str, deactivated: bool):
self.db_pool.simple_update_one_txn(
txn=txn,
table="users",
keyvalues={"name": user_id},
updatevalues={"deactivated": 1 if deactivated else 0},
)
self._invalidate_cache_and_stream(
txn, self.get_user_deactivated_status, (user_id,)
)
txn.call_after(self.is_guest.invalidate, (user_id,))

@cached()
async def is_guest(self, user_id: str) -> bool:
res = await self.db_pool.simple_select_one_onecol(
table="users",
keyvalues={"name": user_id},
retcol="is_guest",
allow_none=True,
desc="is_guest",
)

return res if res else False


class RegistrationStore(RegistrationBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs):
class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)

self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors

self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")

async def add_access_token_to_user(
self,
user_id: str,
Expand Down Expand Up @@ -1138,19 +1189,19 @@ async def register_user(
def _register_user(
self,
txn,
user_id,
password_hash,
was_guest,
make_guest,
appservice_id,
create_profile_with_displayname,
admin,
user_type,
shadow_banned,
user_id: str,
password_hash: Optional[str],
was_guest: bool,
make_guest: bool,
appservice_id: Optional[str],
create_profile_with_displayname: Optional[str],
admin: bool,
user_type: Optional[str],
shadow_banned: bool,
):
user_id_obj = UserID.from_string(user_id)

now = int(self.clock.time())
now = int(self._clock.time())

try:
if was_guest:
Expand Down Expand Up @@ -1374,18 +1425,6 @@ def f(txn):

await self.db_pool.runInteraction("delete_access_token", f)

@cached()
async def is_guest(self, user_id: str) -> bool:
res = await self.db_pool.simple_select_one_onecol(
table="users",
keyvalues={"name": user_id},
retcol="is_guest",
allow_none=True,
desc="is_guest",
)

return res if res else False

async def add_user_pending_deactivation(self, user_id: str) -> None:
"""
Adds a user to the table of users who need to be parted from all the rooms they're
Expand Down Expand Up @@ -1479,7 +1518,7 @@ def validate_threepid_session_txn(txn):
txn,
table="threepid_validation_session",
keyvalues={"session_id": session_id},
updatevalues={"validated_at": self.clock.time_msec()},
updatevalues={"validated_at": self._clock.time_msec()},
)

return next_link
Expand Down Expand Up @@ -1547,35 +1586,6 @@ def start_or_continue_validation_session_txn(txn):
start_or_continue_validation_session_txn,
)

async def set_user_deactivated_status(
self, user_id: str, deactivated: bool
) -> None:
"""Set the `deactivated` property for the provided user to the provided value.

Args:
user_id: The ID of the user to set the status for.
deactivated: The value to set for `deactivated`.
"""

await self.db_pool.runInteraction(
"set_user_deactivated_status",
self.set_user_deactivated_status_txn,
user_id,
deactivated,
)

def set_user_deactivated_status_txn(self, txn, user_id, deactivated):
self.db_pool.simple_update_one_txn(
txn=txn,
table="users",
keyvalues={"name": user_id},
updatevalues={"deactivated": 1 if deactivated else 0},
)
self._invalidate_cache_and_stream(
txn, self.get_user_deactivated_status, (user_id,)
)
txn.call_after(self.is_guest.invalidate, (user_id,))


def find_max_generated_user_id_localpart(cur: Cursor) -> int:
"""
Expand Down