Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add type hints to some storage classes #11307

Merged
merged 8 commits into from
Nov 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/11307.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add type hints to storage classes.
7 changes: 0 additions & 7 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@ exclude = (?x)
|synapse/storage/databases/main/__init__.py
|synapse/storage/databases/main/account_data.py
|synapse/storage/databases/main/cache.py
|synapse/storage/databases/main/censor_events.py
|synapse/storage/databases/main/deviceinbox.py
|synapse/storage/databases/main/devices.py
|synapse/storage/databases/main/directory.py
|synapse/storage/databases/main/e2e_room_keys.py
Expand All @@ -38,19 +36,15 @@ exclude = (?x)
|synapse/storage/databases/main/events_bg_updates.py
|synapse/storage/databases/main/events_forward_extremities.py
|synapse/storage/databases/main/events_worker.py
|synapse/storage/databases/main/filtering.py
|synapse/storage/databases/main/group_server.py
|synapse/storage/databases/main/lock.py
|synapse/storage/databases/main/media_repository.py
|synapse/storage/databases/main/metrics.py
|synapse/storage/databases/main/monthly_active_users.py
|synapse/storage/databases/main/openid.py
|synapse/storage/databases/main/presence.py
|synapse/storage/databases/main/profile.py
|synapse/storage/databases/main/purge_events.py
|synapse/storage/databases/main/push_rule.py
|synapse/storage/databases/main/receipts.py
|synapse/storage/databases/main/rejections.py
|synapse/storage/databases/main/room.py
|synapse/storage/databases/main/room_batch.py
|synapse/storage/databases/main/roommember.py
Expand All @@ -59,7 +53,6 @@ exclude = (?x)
|synapse/storage/databases/main/state.py
|synapse/storage/databases/main/state_deltas.py
|synapse/storage/databases/main/stats.py
|synapse/storage/databases/main/tags.py
|synapse/storage/databases/main/transactions.py
|synapse/storage/databases/main/user_directory.py
|synapse/storage/databases/main/user_erasure_store.py
Expand Down
30 changes: 16 additions & 14 deletions synapse/storage/databases/main/censor_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
# limitations under the License.

import logging
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional

from synapse.events.utils import prune_event_dict
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.util import json_encoder
Expand All @@ -41,7 +41,7 @@ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
hs.get_clock().looping_call(self._censor_redactions, 5 * 60 * 1000)

@wrap_as_background_process("_censor_redactions")
async def _censor_redactions(self):
async def _censor_redactions(self) -> None:
"""Censors all redactions older than the configured period that haven't
been censored yet.

Expand Down Expand Up @@ -105,7 +105,7 @@ async def _censor_redactions(self):
and original_event.internal_metadata.is_redacted()
):
# Redaction was allowed
pruned_json = json_encoder.encode(
pruned_json: Optional[str] = json_encoder.encode(
prune_event_dict(
original_event.room_version, original_event.get_dict()
)
Expand All @@ -116,7 +116,7 @@ async def _censor_redactions(self):

updates.append((redaction_id, event_id, pruned_json))

def _update_censor_txn(txn):
def _update_censor_txn(txn: LoggingTransaction) -> None:
for redaction_id, event_id, pruned_json in updates:
if pruned_json:
self._censor_event_txn(txn, event_id, pruned_json)
Expand All @@ -130,14 +130,16 @@ def _update_censor_txn(txn):

await self.db_pool.runInteraction("_update_censor_txn", _update_censor_txn)

def _censor_event_txn(self, txn, event_id, pruned_json):
def _censor_event_txn(
self, txn: LoggingTransaction, event_id: str, pruned_json: str
) -> None:
"""Censor an event by replacing its JSON in the event_json table with the
provided pruned JSON.

Args:
txn (LoggingTransaction): The database transaction.
event_id (str): The ID of the event to censor.
pruned_json (str): The pruned JSON
txn: The database transaction.
event_id: The ID of the event to censor.
pruned_json: The pruned JSON
"""
self.db_pool.simple_update_one_txn(
txn,
Expand All @@ -157,7 +159,7 @@ async def expire_event(self, event_id: str) -> None:
# Try to retrieve the event's content from the database or the event cache.
event = await self.get_event(event_id)

def delete_expired_event_txn(txn):
def delete_expired_event_txn(txn: LoggingTransaction) -> None:
# Delete the expiry timestamp associated with this event from the database.
self._delete_event_expiry_txn(txn, event_id)

Expand Down Expand Up @@ -194,14 +196,14 @@ def delete_expired_event_txn(txn):
"delete_expired_event", delete_expired_event_txn
)

def _delete_event_expiry_txn(self, txn, event_id):
def _delete_event_expiry_txn(self, txn: LoggingTransaction, event_id: str) -> None:
"""Delete the expiry timestamp associated with an event ID without deleting the
actual event.

Args:
txn (LoggingTransaction): The transaction to use to perform the deletion.
event_id (str): The event ID to delete the associated expiry timestamp of.
txn: The transaction to use to perform the deletion.
event_id: The event ID to delete the associated expiry timestamp of.
"""
return self.db_pool.simple_delete_txn(
self.db_pool.simple_delete_txn(
txn=txn, table="event_expiry", keyvalues={"event_id": event_id}
)
52 changes: 36 additions & 16 deletions synapse/storage/databases/main/deviceinbox.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright 2016 OpenMarket Ltd
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -19,9 +20,17 @@
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.replication.tcp.streams import ToDeviceStream
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.storage.util.id_generators import (
AbstractStreamIdGenerator,
MultiWriterIdGenerator,
StreamIdGenerator,
)
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.expiringcache import ExpiringCache
Expand All @@ -34,14 +43,21 @@


class DeviceInboxWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)

self._instance_name = hs.get_instance_name()

# Map of (user_id, device_id) to the last stream_id that has been
# deleted up to. This is so that we can no op deletions.
self._last_device_delete_cache = ExpiringCache(
self._last_device_delete_cache: ExpiringCache[
Tuple[str, Optional[str]], int
] = ExpiringCache(
cache_name="last_device_delete_cache",
clock=self._clock,
max_len=10000,
Expand All @@ -53,14 +69,16 @@ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
self._instance_name in hs.config.worker.writers.to_device
)

self._device_inbox_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
stream_name="to_device",
instance_name=self._instance_name,
tables=[("device_inbox", "instance_name", "stream_id")],
sequence_name="device_inbox_sequence",
writers=hs.config.worker.writers.to_device,
self._device_inbox_id_gen: AbstractStreamIdGenerator = (
MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
stream_name="to_device",
instance_name=self._instance_name,
tables=[("device_inbox", "instance_name", "stream_id")],
sequence_name="device_inbox_sequence",
writers=hs.config.worker.writers.to_device,
)
)
else:
self._can_write_to_device = True
Expand Down Expand Up @@ -101,6 +119,8 @@ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):

def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == ToDeviceStream.NAME:
# If replication is happening than postgres must be being used.
assert isinstance(self._device_inbox_id_gen, MultiWriterIdGenerator)
Comment on lines +122 to +123
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is rather unfortunate since this code can only be called if replication is in use, but not sure if there's a "better" way to do it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't think of a better way either!

self._device_inbox_id_gen.advance(instance_name, token)
for row in rows:
if row.entity.startswith("@"):
Expand Down Expand Up @@ -220,11 +240,11 @@ def delete_messages_for_device_txn(txn):
log_kv({"message": f"deleted {count} messages for device", "count": count})

# Update the cache, ensuring that we only ever increase the value
last_deleted_stream_id = self._last_device_delete_cache.get(
updated_last_deleted_stream_id = self._last_device_delete_cache.get(
(user_id, device_id), 0
)
self._last_device_delete_cache[(user_id, device_id)] = max(
last_deleted_stream_id, up_to_stream_id
updated_last_deleted_stream_id, up_to_stream_id
)

return count
Expand Down Expand Up @@ -432,7 +452,7 @@ def add_messages_txn(txn, now_ms, stream_id):
)

async with self._device_inbox_id_gen.get_next() as stream_id:
now_ms = self.clock.time_msec()
now_ms = self._clock.time_msec()
await self.db_pool.runInteraction(
"add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
)
Expand Down Expand Up @@ -483,7 +503,7 @@ def add_messages_txn(txn, now_ms, stream_id):
)

async with self._device_inbox_id_gen.get_next() as stream_id:
now_ms = self.clock.time_msec()
now_ms = self._clock.time_msec()
await self.db_pool.runInteraction(
"add_messages_from_remote_to_device_inbox",
add_messages_txn,
Expand Down
6 changes: 4 additions & 2 deletions synapse/storage/databases/main/filtering.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -18,6 +19,7 @@

from synapse.api.errors import Codes, SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import LoggingTransaction
from synapse.types import JsonDict
from synapse.util.caches.descriptors import cached

Expand Down Expand Up @@ -49,7 +51,7 @@ async def add_user_filter(self, user_localpart: str, user_filter: JsonDict) -> i

# Need an atomic transaction to SELECT the maximal ID so far then
# INSERT a new one
def _do_txn(txn):
def _do_txn(txn: LoggingTransaction) -> int:
sql = (
"SELECT filter_id FROM user_filters "
"WHERE user_id = ? AND filter_json = ?"
Expand All @@ -61,7 +63,7 @@ def _do_txn(txn):

sql = "SELECT MAX(filter_id) FROM user_filters WHERE user_id = ?"
txn.execute(sql, (user_localpart,))
max_id = txn.fetchone()[0]
max_id = txn.fetchone()[0] # type: ignore[index]
if max_id is None:
filter_id = 0
else:
Expand Down
6 changes: 4 additions & 2 deletions synapse/storage/databases/main/lock.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
import logging
from types import TracebackType
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Type
from typing import TYPE_CHECKING, Optional, Tuple, Type
from weakref import WeakValueDictionary

from twisted.internet.interfaces import IReactorCore
Expand Down Expand Up @@ -62,7 +62,9 @@ def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"

# A map from `(lock_name, lock_key)` to the token of any locks that we
# think we currently hold.
self._live_tokens: Dict[Tuple[str, str], Lock] = WeakValueDictionary()
self._live_tokens: WeakValueDictionary[
Tuple[str, str], Lock
] = WeakValueDictionary()

# When we shut down we want to remove the locks. Technically this can
# lead to a race, as we may drop the lock while we are still processing.
Expand Down
17 changes: 16 additions & 1 deletion synapse/storage/databases/main/openid.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,21 @@
# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

from synapse.storage._base import SQLBaseStore
from synapse.storage.database import LoggingTransaction


class OpenIdStore(SQLBaseStore):
Expand All @@ -20,7 +35,7 @@ async def insert_open_id_token(
async def get_user_id_for_open_id_token(
self, token: str, ts_now_ms: int
) -> Optional[str]:
def get_user_id_for_token_txn(txn):
def get_user_id_for_token_txn(txn: LoggingTransaction) -> Optional[str]:
sql = (
"SELECT user_id FROM open_id_tokens"
" WHERE token = ? AND ? <= ts_valid_until_ms"
Expand Down
Loading