From 52ac3db5038797d414bab4b4a80e0929c8a510f7 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 21 Oct 2020 10:32:37 +0100 Subject: [PATCH 1/5] Format SQL --- synapse/storage/databases/main/registration.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 4c843b76798c..80d9606783cf 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -328,13 +328,17 @@ def set_server_admin_txn(txn): await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn) def _query_for_auth(self, txn, token): - sql = ( - "SELECT users.name, users.is_guest, users.shadow_banned, access_tokens.id as token_id," - " access_tokens.device_id, access_tokens.valid_until_ms" - " FROM users" - " INNER JOIN access_tokens on users.name = access_tokens.user_id" - " WHERE token = ?" - ) + sql = """ + SELECT users.name, + users.is_guest, + users.shadow_banned, + access_tokens.id as token_id, + access_tokens.device_id, + access_tokens.valid_until_ms + FROM users + INNER JOIN access_tokens on users.name = access_tokens.user_id + WHERE token = ? + """ txn.execute(sql, (token,)) rows = self.db_pool.cursor_to_dict(txn) From 3bbac3ff9f62073ed6b7109e143b42c8780c3a09 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 21 Oct 2020 10:41:12 +0100 Subject: [PATCH 2/5] Add registration store to mypy --- mypy.ini | 1 + synapse/storage/databases/main/__init__.py | 1 - .../storage/databases/main/registration.py | 92 ++++++++++--------- 3 files changed, 49 insertions(+), 45 deletions(-) diff --git a/mypy.ini b/mypy.ini index 5e9f7b1259e4..59d9074c3b06 100644 --- a/mypy.ini +++ b/mypy.ini @@ -57,6 +57,7 @@ files = synapse/spam_checker_api, synapse/state, synapse/storage/databases/main/events.py, + synapse/storage/databases/main/registration.py, synapse/storage/databases/main/stream.py, synapse/storage/databases/main/ui_auth.py, synapse/storage/database.py, diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 9b16f45f3eff..43660ec4fb5c 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -146,7 +146,6 @@ def __init__(self, database: DatabasePool, db_conn, hs): db_conn, "e2e_cross_signing_keys", "stream_id" ) - self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id") self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id") self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id") self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id") diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 80d9606783cf..d856b769527b 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -21,9 +21,11 @@ from synapse.api.constants import UserTypes from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError from synapse.metrics.background_process_metrics import wrap_as_background_process -from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore +from synapse.storage.databases.main.stats import StatsStore from synapse.storage.types import Cursor +from synapse.storage.util.id_generators import IdGenerator from synapse.storage.util.sequence import build_sequence_generator from synapse.types import UserID from synapse.util.caches.descriptors import cached @@ -33,7 +35,7 @@ logger = logging.getLogger(__name__) -class RegistrationWorkerStore(SQLBaseStore): +class RegistrationWorkerStore(CacheInvalidationWorkerStore): def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) @@ -1020,13 +1022,56 @@ def _bg_user_threepids_grandfather_txn(txn): return 1 + async def set_user_deactivated_status( + self, user_id: str, deactivated: bool + ) -> None: + """Set the `deactivated` property for the provided user to the provided value. + + Args: + user_id: The ID of the user to set the status for. + deactivated: The value to set for `deactivated`. + """ + + await self.db_pool.runInteraction( + "set_user_deactivated_status", + self.set_user_deactivated_status_txn, + user_id, + deactivated, + ) + + def set_user_deactivated_status_txn(self, txn, user_id, deactivated): + self.db_pool.simple_update_one_txn( + txn=txn, + table="users", + keyvalues={"name": user_id}, + updatevalues={"deactivated": 1 if deactivated else 0}, + ) + self._invalidate_cache_and_stream( + txn, self.get_user_deactivated_status, (user_id,) + ) + txn.call_after(self.is_guest.invalidate, (user_id,)) -class RegistrationStore(RegistrationBackgroundUpdateStore): + @cached() + async def is_guest(self, user_id: str) -> bool: + res = await self.db_pool.simple_select_one_onecol( + table="users", + keyvalues={"name": user_id}, + retcol="is_guest", + allow_none=True, + desc="is_guest", + ) + + return res if res else False + + +class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors + self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id") + async def add_access_token_to_user( self, user_id: str, @@ -1378,18 +1423,6 @@ def f(txn): await self.db_pool.runInteraction("delete_access_token", f) - @cached() - async def is_guest(self, user_id: str) -> bool: - res = await self.db_pool.simple_select_one_onecol( - table="users", - keyvalues={"name": user_id}, - retcol="is_guest", - allow_none=True, - desc="is_guest", - ) - - return res if res else False - async def add_user_pending_deactivation(self, user_id: str) -> None: """ Adds a user to the table of users who need to be parted from all the rooms they're @@ -1551,35 +1584,6 @@ def start_or_continue_validation_session_txn(txn): start_or_continue_validation_session_txn, ) - async def set_user_deactivated_status( - self, user_id: str, deactivated: bool - ) -> None: - """Set the `deactivated` property for the provided user to the provided value. - - Args: - user_id: The ID of the user to set the status for. - deactivated: The value to set for `deactivated`. - """ - - await self.db_pool.runInteraction( - "set_user_deactivated_status", - self.set_user_deactivated_status_txn, - user_id, - deactivated, - ) - - def set_user_deactivated_status_txn(self, txn, user_id, deactivated): - self.db_pool.simple_update_one_txn( - txn=txn, - table="users", - keyvalues={"name": user_id}, - updatevalues={"deactivated": 1 if deactivated else 0}, - ) - self._invalidate_cache_and_stream( - txn, self.get_user_deactivated_status, (user_id,) - ) - txn.call_after(self.is_guest.invalidate, (user_id,)) - def find_max_generated_user_id_localpart(cur: Cursor) -> int: """ From 94f6436c3638d90fcb5c46652b60ad4ffa6c1492 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 21 Oct 2020 10:48:44 +0100 Subject: [PATCH 3/5] Add typing info to registration --- .../storage/databases/main/registration.py | 46 ++++++++++--------- 1 file changed, 24 insertions(+), 22 deletions(-) diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index d856b769527b..f2b7233d293c 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -16,7 +16,7 @@ # limitations under the License. import logging import re -from typing import Any, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple from synapse.api.constants import UserTypes from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError @@ -30,17 +30,19 @@ from synapse.types import UserID from synapse.util.caches.descriptors import cached +if TYPE_CHECKING: + from synapse.server import HomeServer + THIRTY_MINUTES_IN_MS = 30 * 60 * 1000 logger = logging.getLogger(__name__) class RegistrationWorkerStore(CacheInvalidationWorkerStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) self.config = hs.config - self.clock = hs.get_clock() # Note: we don't check this sequence for consistency as we'd have to # call `find_max_generated_user_id_localpart` each time, which is @@ -57,7 +59,7 @@ def __init__(self, database: DatabasePool, db_conn, hs): # Create a background job for culling expired 3PID validity tokens if hs.config.run_background_tasks: - self.clock.looping_call( + self._clock.looping_call( self.cull_expired_threepid_validation_tokens, THIRTY_MINUTES_IN_MS ) @@ -94,7 +96,7 @@ async def is_trial_user(self, user_id: str) -> bool: if not info: return False - now = self.clock.time_msec() + now = self._clock.time_msec() trial_duration_ms = self.config.mau_trial_days * 24 * 60 * 60 * 1000 is_trial = (now - info["creation_ts"] * 1000) < trial_duration_ms return is_trial @@ -259,7 +261,7 @@ def select_users_txn(txn, now_ms, renew_at): return await self.db_pool.runInteraction( "get_users_expiring_soon", select_users_txn, - self.clock.time_msec(), + self._clock.time_msec(), self.config.account_validity.renew_at, ) @@ -809,7 +811,7 @@ def cull_expired_threepid_validation_tokens_txn(txn, ts): await self.db_pool.runInteraction( "cull_expired_threepid_validation_tokens", cull_expired_threepid_validation_tokens_txn, - self.clock.time_msec(), + self._clock.time_msec(), ) @wrap_as_background_process("account_validity_set_expiration_dates") @@ -896,10 +898,10 @@ async def del_user_pending_deactivation(self, user_id: str) -> None: class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) - self.clock = hs.get_clock() + self._clock = hs.get_clock() self.config = hs.config self.db_pool.updates.register_background_index_update( @@ -1039,7 +1041,7 @@ async def set_user_deactivated_status( deactivated, ) - def set_user_deactivated_status_txn(self, txn, user_id, deactivated): + def set_user_deactivated_status_txn(self, txn, user_id: str, deactivated: bool): self.db_pool.simple_update_one_txn( txn=txn, table="users", @@ -1065,7 +1067,7 @@ async def is_guest(self, user_id: str) -> bool: class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors @@ -1187,19 +1189,19 @@ async def register_user( def _register_user( self, txn, - user_id, - password_hash, - was_guest, - make_guest, - appservice_id, - create_profile_with_displayname, - admin, - user_type, - shadow_banned, + user_id: str, + password_hash: Optional[str], + was_guest: bool, + make_guest: bool, + appservice_id: Optional[str], + create_profile_with_displayname: Optional[str], + admin: bool, + user_type: Optional[str], + shadow_banned: bool, ): user_id_obj = UserID.from_string(user_id) - now = int(self.clock.time()) + now = int(self._clock.time()) try: if was_guest: @@ -1516,7 +1518,7 @@ def validate_threepid_session_txn(txn): txn, table="threepid_validation_session", keyvalues={"session_id": session_id}, - updatevalues={"validated_at": self.clock.time_msec()}, + updatevalues={"validated_at": self._clock.time_msec()}, ) return next_link From 94dcb64c2a20c98dc801a753827ad4350f9a4351 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 21 Oct 2020 17:10:32 +0100 Subject: [PATCH 4/5] Newsfile --- changelog.d/8615.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/8615.misc diff --git a/changelog.d/8615.misc b/changelog.d/8615.misc new file mode 100644 index 000000000000..79fa7b7ff84d --- /dev/null +++ b/changelog.d/8615.misc @@ -0,0 +1 @@ +Type hints for `RegistrationStore`. From 03e073760215fac6f71d25932a4fa007e263a935 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 22 Oct 2020 10:13:52 +0100 Subject: [PATCH 5/5] Add types to db_conn --- synapse/storage/databases/main/registration.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index f2b7233d293c..b0329e17ec6e 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -24,7 +24,7 @@ from synapse.storage.database import DatabasePool from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.databases.main.stats import StatsStore -from synapse.storage.types import Cursor +from synapse.storage.types import Connection, Cursor from synapse.storage.util.id_generators import IdGenerator from synapse.storage.util.sequence import build_sequence_generator from synapse.types import UserID @@ -39,7 +39,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): super().__init__(database, db_conn, hs) self.config = hs.config @@ -898,7 +898,7 @@ async def del_user_pending_deactivation(self, user_id: str) -> None: class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): super().__init__(database, db_conn, hs) self._clock = hs.get_clock() @@ -1067,7 +1067,7 @@ async def is_guest(self, user_id: str) -> bool: class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): super().__init__(database, db_conn, hs) self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors