From 954921736b88de25c775c519a206449e46b3bf07 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 14 Sep 2023 12:46:30 +0100 Subject: [PATCH] Refactor `get_user_by_id` (#16316) --- changelog.d/16316.misc | 1 + synapse/api/auth/internal.py | 2 +- synapse/api/auth/msc3861_delegated.py | 2 +- synapse/handlers/account.py | 2 +- synapse/handlers/admin.py | 49 ++++++------ synapse/handlers/message.py | 6 +- synapse/module_api/__init__.py | 4 +- synapse/rest/consent/consent_resource.py | 2 +- .../server_notices/consent_server_notices.py | 6 +- synapse/storage/databases/main/client_ips.py | 11 +++ .../storage/databases/main/registration.py | 76 ++++++------------- synapse/types/__init__.py | 10 ++- tests/api/test_auth.py | 12 ++- tests/storage/test_registration.py | 48 ++++++------ 14 files changed, 108 insertions(+), 123 deletions(-) create mode 100644 changelog.d/16316.misc diff --git a/changelog.d/16316.misc b/changelog.d/16316.misc new file mode 100644 index 000000000000..aa0644f278c4 --- /dev/null +++ b/changelog.d/16316.misc @@ -0,0 +1 @@ +Refactor `get_user_by_id`. diff --git a/synapse/api/auth/internal.py b/synapse/api/auth/internal.py index 6a5fd44ec01c..a75f6f2cc44e 100644 --- a/synapse/api/auth/internal.py +++ b/synapse/api/auth/internal.py @@ -268,7 +268,7 @@ async def get_user_by_access_token( stored_user = await self.store.get_user_by_id(user_id) if not stored_user: raise InvalidClientTokenError("Unknown user_id %s" % user_id) - if not stored_user["is_guest"]: + if not stored_user.is_guest: raise InvalidClientTokenError( "Guest access token used for regular user" ) diff --git a/synapse/api/auth/msc3861_delegated.py b/synapse/api/auth/msc3861_delegated.py index ef5d3f9b815c..31bb035cc846 100644 --- a/synapse/api/auth/msc3861_delegated.py +++ b/synapse/api/auth/msc3861_delegated.py @@ -300,7 +300,7 @@ async def get_user_by_access_token( user_id = UserID(username, self._hostname) # First try to find a user from the username claim - user_info = await self.store.get_userinfo_by_id(user_id=user_id.to_string()) + user_info = await self.store.get_user_by_id(user_id=user_id.to_string()) if user_info is None: # If the user does not exist, we should create it on the fly # TODO: we could use SCIM to provision users ahead of time and listen diff --git a/synapse/handlers/account.py b/synapse/handlers/account.py index c05a14304c1e..fa043cca867d 100644 --- a/synapse/handlers/account.py +++ b/synapse/handlers/account.py @@ -102,7 +102,7 @@ async def _get_local_account_status(self, user_id: UserID) -> JsonDict: """ status = {"exists": False} - userinfo = await self._main_store.get_userinfo_by_id(user_id.to_string()) + userinfo = await self._main_store.get_user_by_id(user_id.to_string()) if userinfo is not None: status = { diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index 2f0e5f3b0a9e..7092ff3449ca 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -18,7 +18,7 @@ from synapse.api.constants import Direction, Membership from synapse.events import EventBase -from synapse.types import JsonDict, RoomStreamToken, StateMap, UserID +from synapse.types import JsonDict, RoomStreamToken, StateMap, UserID, UserInfo from synapse.visibility import filter_events_for_client if TYPE_CHECKING: @@ -57,38 +57,30 @@ async def get_whois(self, user: UserID) -> JsonDict: async def get_user(self, user: UserID) -> Optional[JsonDict]: """Function to get user details""" - user_info_dict = await self._store.get_user_by_id(user.to_string()) - if user_info_dict is None: + user_info: Optional[UserInfo] = await self._store.get_user_by_id( + user.to_string() + ) + if user_info is None: return None - # Restrict returned information to a known set of fields. This prevents additional - # fields added to get_user_by_id from modifying Synapse's external API surface. - user_info_to_return = { - "name", - "admin", - "deactivated", - "locked", - "shadow_banned", - "creation_ts", - "appservice_id", - "consent_server_notice_sent", - "consent_version", - "consent_ts", - "user_type", - "is_guest", - "last_seen_ts", + user_info_dict = { + "name": user.to_string(), + "admin": user_info.is_admin, + "deactivated": user_info.is_deactivated, + "locked": user_info.locked, + "shadow_banned": user_info.is_shadow_banned, + "creation_ts": user_info.creation_ts, + "appservice_id": user_info.appservice_id, + "consent_server_notice_sent": user_info.consent_server_notice_sent, + "consent_version": user_info.consent_version, + "consent_ts": user_info.consent_ts, + "user_type": user_info.user_type, + "is_guest": user_info.is_guest, } if self._msc3866_enabled: # Only include the approved flag if support for MSC3866 is enabled. - user_info_to_return.add("approved") - - # Restrict returned keys to a known set. - user_info_dict = { - key: value - for key, value in user_info_dict.items() - if key in user_info_to_return - } + user_info_dict["approved"] = user_info.approved # Add additional user metadata profile = await self._store.get_profileinfo(user) @@ -105,6 +97,9 @@ async def get_user(self, user: UserID) -> Optional[JsonDict]: user_info_dict["external_ids"] = external_ids user_info_dict["erased"] = await self._store.is_user_erased(user.to_string()) + last_seen_ts = await self._store.get_last_seen_for_user_id(user.to_string()) + user_info_dict["last_seen_ts"] = last_seen_ts + return user_info_dict async def export_user_data(self, user_id: str, writer: "ExfiltrationWriter") -> Any: diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index d6be18cdefff..c036578a3dce 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -828,13 +828,13 @@ async def assert_accepted_privacy_policy(self, requester: Requester) -> None: u = await self.store.get_user_by_id(user_id) assert u is not None - if u["user_type"] in (UserTypes.SUPPORT, UserTypes.BOT): + if u.user_type in (UserTypes.SUPPORT, UserTypes.BOT): # support and bot users are not required to consent return - if u["appservice_id"] is not None: + if u.appservice_id is not None: # users registered by an appservice are exempt return - if u["consent_version"] == self.config.consent.user_consent_version: + if u.consent_version == self.config.consent.user_consent_version: return consent_uri = self._consent_uri_builder.build_user_consent_uri(user.localpart) diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index d6efe10a28ba..7ec202be2342 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -572,7 +572,7 @@ async def get_userinfo_by_id(self, user_id: str) -> Optional[UserInfo]: Returns: UserInfo object if a user was found, otherwise None """ - return await self._store.get_userinfo_by_id(user_id) + return await self._store.get_user_by_id(user_id) async def get_user_by_req( self, @@ -1878,7 +1878,7 @@ async def put_global( raise TypeError(f"new_data must be a dict; got {type(new_data).__name__}") # Ensure the user exists, so we don't just write to users that aren't there. - if await self._store.get_userinfo_by_id(user_id) is None: + if await self._store.get_user_by_id(user_id) is None: raise ValueError(f"User {user_id} does not exist on this server.") await self._handler.add_account_data_for_user(user_id, data_type, new_data) diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py index 25f9ea285bca..88d3ec1baf61 100644 --- a/synapse/rest/consent/consent_resource.py +++ b/synapse/rest/consent/consent_resource.py @@ -129,7 +129,7 @@ async def _async_render_GET(self, request: Request) -> None: if u is None: raise NotFoundError("Unknown user") - has_consented = u["consent_version"] == version + has_consented = u.consent_version == version userhmac = userhmac_bytes.decode("ascii") try: diff --git a/synapse/server_notices/consent_server_notices.py b/synapse/server_notices/consent_server_notices.py index 94025ba41f7d..a879b6505e4e 100644 --- a/synapse/server_notices/consent_server_notices.py +++ b/synapse/server_notices/consent_server_notices.py @@ -79,15 +79,15 @@ async def maybe_send_server_notice_to_user(self, user_id: str) -> None: if u is None: return - if u["is_guest"] and not self._send_to_guests: + if u.is_guest and not self._send_to_guests: # don't send to guests return - if u["consent_version"] == self._current_consent_version: + if u.consent_version == self._current_consent_version: # user has already consented return - if u["consent_server_notice_sent"] == self._current_consent_version: + if u.consent_server_notice_sent == self._current_consent_version: # we've already sent a notice to the user return diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py index d8d333e11d04..7da47c3dd727 100644 --- a/synapse/storage/databases/main/client_ips.py +++ b/synapse/storage/databases/main/client_ips.py @@ -764,3 +764,14 @@ async def get_user_ip_and_agents( } return list(results.values()) + + async def get_last_seen_for_user_id(self, user_id: str) -> Optional[int]: + """Get the last seen timestamp for a user, if we have it.""" + + return await self.db_pool.simple_select_one_onecol( + table="user_ips", + keyvalues={"user_id": user_id}, + retcol="MAX(last_seen)", + allow_none=True, + desc="get_last_seen_for_user_id", + ) diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index e34156dc5584..cc964604e283 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -16,7 +16,7 @@ import logging import random import re -from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union, cast +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast import attr @@ -192,8 +192,8 @@ def __init__( ) @cached() - async def get_user_by_id(self, user_id: str) -> Optional[Mapping[str, Any]]: - """Deprecated: use get_userinfo_by_id instead""" + async def get_user_by_id(self, user_id: str) -> Optional[UserInfo]: + """Returns info about the user account, if it exists.""" def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]: # We could technically use simple_select_one here, but it would not perform @@ -202,16 +202,12 @@ def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]: txn.execute( """ SELECT - name, password_hash, is_guest, admin, consent_version, consent_ts, + name, is_guest, admin, consent_version, consent_ts, consent_server_notice_sent, appservice_id, creation_ts, user_type, deactivated, COALESCE(shadow_banned, FALSE) AS shadow_banned, COALESCE(approved, TRUE) AS approved, - COALESCE(locked, FALSE) AS locked, last_seen_ts + COALESCE(locked, FALSE) AS locked FROM users - LEFT JOIN ( - SELECT user_id, MAX(last_seen) AS last_seen_ts - FROM user_ips GROUP BY user_id - ) ls ON users.name = ls.user_id WHERE name = ? """, (user_id,), @@ -228,51 +224,23 @@ def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]: desc="get_user_by_id", func=get_user_by_id_txn, ) - - if row is not None: - # If we're using SQLite our boolean values will be integers. Because we - # present some of this data as is to e.g. server admins via REST APIs, we - # want to make sure we're returning the right type of data. - # Note: when adding a column name to this list, be wary of NULLable columns, - # since NULL values will be turned into False. - boolean_columns = [ - "admin", - "deactivated", - "shadow_banned", - "approved", - "locked", - ] - for column in boolean_columns: - row[column] = bool(row[column]) - - return row - - async def get_userinfo_by_id(self, user_id: str) -> Optional[UserInfo]: - """Get a UserInfo object for a user by user ID. - - Note! Currently uses the cache of `get_user_by_id`. Once that deprecated method is removed, - this method should be cached. - - Args: - user_id: The user to fetch user info for. - Returns: - `UserInfo` object if user found, otherwise `None`. - """ - user_data = await self.get_user_by_id(user_id) - if not user_data: + if row is None: return None + return UserInfo( - appservice_id=user_data["appservice_id"], - consent_server_notice_sent=user_data["consent_server_notice_sent"], - consent_version=user_data["consent_version"], - creation_ts=user_data["creation_ts"], - is_admin=bool(user_data["admin"]), - is_deactivated=bool(user_data["deactivated"]), - is_guest=bool(user_data["is_guest"]), - is_shadow_banned=bool(user_data["shadow_banned"]), - user_id=UserID.from_string(user_data["name"]), - user_type=user_data["user_type"], - last_seen_ts=user_data["last_seen_ts"], + appservice_id=row["appservice_id"], + consent_server_notice_sent=row["consent_server_notice_sent"], + consent_version=row["consent_version"], + consent_ts=row["consent_ts"], + creation_ts=row["creation_ts"], + is_admin=bool(row["admin"]), + is_deactivated=bool(row["deactivated"]), + is_guest=bool(row["is_guest"]), + is_shadow_banned=bool(row["shadow_banned"]), + user_id=UserID.from_string(row["name"]), + user_type=row["user_type"], + approved=bool(row["approved"]), + locked=bool(row["locked"]), ) async def is_trial_user(self, user_id: str) -> bool: @@ -290,10 +258,10 @@ async def is_trial_user(self, user_id: str) -> bool: now = self._clock.time_msec() days = self.config.server.mau_appservice_trial_days.get( - info["appservice_id"], self.config.server.mau_trial_days + info.appservice_id, self.config.server.mau_trial_days ) trial_duration_ms = days * 24 * 60 * 60 * 1000 - is_trial = (now - info["creation_ts"] * 1000) < trial_duration_ms + is_trial = (now - info.creation_ts * 1000) < trial_duration_ms return is_trial @cached() diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py index 488714f60cb6..76b0e3e694f7 100644 --- a/synapse/types/__init__.py +++ b/synapse/types/__init__.py @@ -933,33 +933,37 @@ def get_verify_key_from_cross_signing_key( @attr.s(auto_attribs=True, frozen=True, slots=True) class UserInfo: - """Holds information about a user. Result of get_userinfo_by_id. + """Holds information about a user. Result of get_user_by_id. Attributes: user_id: ID of the user. appservice_id: Application service ID that created this user. consent_server_notice_sent: Version of policy documents the user has been sent. consent_version: Version of policy documents the user has consented to. + consent_ts: Time the user consented creation_ts: Creation timestamp of the user. is_admin: True if the user is an admin. is_deactivated: True if the user has been deactivated. is_guest: True if the user is a guest user. is_shadow_banned: True if the user has been shadow-banned. user_type: User type (None for normal user, 'support' and 'bot' other options). - last_seen_ts: Last activity timestamp of the user. + approved: If the user has been "approved" to register on the server. + locked: Whether the user's account has been locked """ user_id: UserID appservice_id: Optional[int] consent_server_notice_sent: Optional[str] consent_version: Optional[str] + consent_ts: Optional[int] user_type: Optional[str] creation_ts: int is_admin: bool is_deactivated: bool is_guest: bool is_shadow_banned: bool - last_seen_ts: Optional[int] + approved: bool + locked: bool class UserProfile(TypedDict): diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index dcd01d56885c..e00d7215dfeb 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -188,8 +188,11 @@ def test_get_user_by_req_appservice_valid_token_valid_user_id(self) -> None: ) app_service.is_interested_in_user = Mock(return_value=True) self.store.get_app_service_by_token = Mock(return_value=app_service) - # This just needs to return a truth-y value. - self.store.get_user_by_id = AsyncMock(return_value={"is_guest": False}) + + class FakeUserInfo: + is_guest = False + + self.store.get_user_by_id = AsyncMock(return_value=FakeUserInfo()) self.store.get_user_by_access_token = AsyncMock(return_value=None) request = Mock(args={}) @@ -341,7 +344,10 @@ def test_get_user_from_macaroon(self) -> None: ) def test_get_guest_user_from_macaroon(self) -> None: - self.store.get_user_by_id = AsyncMock(return_value={"is_guest": True}) + class FakeUserInfo: + is_guest = True + + self.store.get_user_by_id = AsyncMock(return_value=FakeUserInfo()) self.store.get_user_by_access_token = AsyncMock(return_value=None) user_id = "@baldrick:matrix.org" diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index 95c9792d546e..0cca34d355f6 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -16,7 +16,7 @@ from synapse.api.constants import UserTypes from synapse.api.errors import ThreepidValidationError from synapse.server import HomeServer -from synapse.types import JsonDict, UserID +from synapse.types import JsonDict, UserID, UserInfo from synapse.util import Clock from tests.unittest import HomeserverTestCase, override_config @@ -35,24 +35,22 @@ def test_register(self) -> None: self.get_success(self.store.register_user(self.user_id, self.pwhash)) self.assertEqual( - { + UserInfo( # TODO(paul): Surely this field should be 'user_id', not 'name' - "name": self.user_id, - "password_hash": self.pwhash, - "admin": 0, - "is_guest": 0, - "consent_version": None, - "consent_ts": None, - "consent_server_notice_sent": None, - "appservice_id": None, - "creation_ts": 0, - "user_type": None, - "deactivated": 0, - "locked": 0, - "shadow_banned": 0, - "approved": 1, - "last_seen_ts": None, - }, + user_id=UserID.from_string(self.user_id), + is_admin=False, + is_guest=False, + consent_server_notice_sent=None, + consent_ts=None, + consent_version=None, + appservice_id=None, + creation_ts=0, + user_type=None, + is_deactivated=False, + locked=False, + is_shadow_banned=False, + approved=True, + ), (self.get_success(self.store.get_user_by_id(self.user_id))), ) @@ -65,9 +63,11 @@ def test_consent(self) -> None: user = self.get_success(self.store.get_user_by_id(self.user_id)) assert user - self.assertEqual(user["consent_version"], "1") - self.assertGreater(user["consent_ts"], before_consent) - self.assertLess(user["consent_ts"], self.clock.time_msec()) + self.assertEqual(user.consent_version, "1") + self.assertIsNotNone(user.consent_ts) + assert user.consent_ts is not None + self.assertGreater(user.consent_ts, before_consent) + self.assertLess(user.consent_ts, self.clock.time_msec()) def test_add_tokens(self) -> None: self.get_success(self.store.register_user(self.user_id, self.pwhash)) @@ -215,7 +215,7 @@ def test_approval_not_required(self) -> None: user = self.get_success(self.store.get_user_by_id(self.user_id)) assert user is not None - self.assertTrue(user["approved"]) + self.assertTrue(user.approved) approved = self.get_success(self.store.is_user_approved(self.user_id)) self.assertTrue(approved) @@ -228,7 +228,7 @@ def test_approval_required(self) -> None: user = self.get_success(self.store.get_user_by_id(self.user_id)) assert user is not None - self.assertFalse(user["approved"]) + self.assertFalse(user.approved) approved = self.get_success(self.store.is_user_approved(self.user_id)) self.assertFalse(approved) @@ -248,7 +248,7 @@ def test_override(self) -> None: user = self.get_success(self.store.get_user_by_id(self.user_id)) self.assertIsNotNone(user) assert user is not None - self.assertEqual(user["approved"], 1) + self.assertEqual(user.approved, 1) approved = self.get_success(self.store.is_user_approved(self.user_id)) self.assertTrue(approved)