Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Database storage profile passes mypy #11342

Merged
merged 2 commits into from
Nov 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/11342.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add type hints to storage classes.
7 changes: 6 additions & 1 deletion mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ exclude = (?x)
|synapse/storage/databases/main/metrics.py
|synapse/storage/databases/main/monthly_active_users.py
|synapse/storage/databases/main/presence.py
|synapse/storage/databases/main/profile.py
|synapse/storage/databases/main/purge_events.py
|synapse/storage/databases/main/push_rule.py
|synapse/storage/databases/main/receipts.py
Expand Down Expand Up @@ -180,6 +179,9 @@ disallow_untyped_defs = True
[mypy-synapse.storage.databases.main.room_batch]
disallow_untyped_defs = True

[mypy-synapse.storage.databases.main.profile]
disallow_untyped_defs = True

[mypy-synapse.storage.databases.main.state_deltas]
disallow_untyped_defs = True

Expand Down Expand Up @@ -282,6 +284,9 @@ disallow_untyped_defs = True
[mypy-tests.handlers.test_user_directory]
disallow_untyped_defs = True

[mypy-tests.storage.test_profile]
disallow_untyped_defs = True

[mypy-tests.storage.test_user_directory]
disallow_untyped_defs = True

Expand Down
12 changes: 8 additions & 4 deletions synapse/storage/databases/main/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import LoggingTransaction
from synapse.storage.databases.main.roommember import ProfileInfo


Expand Down Expand Up @@ -104,7 +105,7 @@ async def update_remote_profile_cache(
desc="update_remote_profile_cache",
)

async def maybe_delete_remote_profile_cache(self, user_id):
async def maybe_delete_remote_profile_cache(self, user_id: str) -> None:
"""Check if we still care about the remote user's profile, and if we
don't then remove their profile from the cache
"""
Expand All @@ -116,9 +117,9 @@ async def maybe_delete_remote_profile_cache(self, user_id):
desc="delete_remote_profile_cache",
)

async def is_subscribed_remote_profile_for_user(self, user_id):
async def is_subscribed_remote_profile_for_user(self, user_id: str) -> bool:
"""Check whether we are interested in a remote user's profile."""
res = await self.db_pool.simple_select_one_onecol(
res: Optional[str] = await self.db_pool.simple_select_one_onecol(
table="group_users",
keyvalues={"user_id": user_id},
retcol="user_id",
Expand All @@ -139,13 +140,16 @@ async def is_subscribed_remote_profile_for_user(self, user_id):

if res:
return True
return False

async def get_remote_profile_cache_entries_that_expire(
self, last_checked: int
) -> List[Dict[str, str]]:
"""Get all users who haven't been checked since `last_checked`"""

def _get_remote_profile_cache_entries_that_expire_txn(txn):
def _get_remote_profile_cache_entries_that_expire_txn(
txn: LoggingTransaction,
) -> List[Dict[str, str]]:
sql = """
SELECT user_id, displayname, avatar_url
FROM remote_profile_cache
Expand Down
9 changes: 6 additions & 3 deletions tests/storage/test_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.test.proto_helpers import MemoryReactor

from synapse.server import HomeServer
from synapse.types import UserID
from synapse.util import Clock

from tests import unittest


class ProfileStoreTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastore()

self.u_frank = UserID.from_string("@frank:test")

def test_displayname(self):
def test_displayname(self) -> None:
self.get_success(self.store.create_profile(self.u_frank.localpart))

self.get_success(
Expand All @@ -48,7 +51,7 @@ def test_displayname(self):
self.get_success(self.store.get_profile_displayname(self.u_frank.localpart))
)

def test_avatar_url(self):
def test_avatar_url(self) -> None:
self.get_success(self.store.create_profile(self.u_frank.localpart))

self.get_success(
Expand Down