This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Type hint the constructors of the data store classes #11555

Merged: 3 commits, Dec 13, 2021
Changes from 1 commit
1 change: 1 addition & 0 deletions changelog.d/11555.misc
@@ -0,0 +1 @@
+Add missing type hints to storage classes.
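
Every file in this diff applies the same mechanical change: the previously unannotated db_conn constructor parameter gains the LoggingDatabaseConnection type, and the import from synapse.storage.database is extended to match. A condensed sketch of the before/after pattern, using a hypothetical ExampleWorkerStore rather than any real store class:

    from typing import TYPE_CHECKING

    from synapse.storage._base import SQLBaseStore
    from synapse.storage.database import DatabasePool, LoggingDatabaseConnection

    if TYPE_CHECKING:
        from synapse.server import HomeServer


    class ExampleWorkerStore(SQLBaseStore):
        # Before this change the signature was:
        #     def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
        def __init__(
            self,
            database: DatabasePool,
            db_conn: LoggingDatabaseConnection,
            hs: "HomeServer",
        ):
            super().__init__(database, db_conn, hs)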
9 changes: 7 additions & 2 deletions synapse/replication/slave/storage/_base.py
@@ -15,7 +15,7 @@
 import logging
 from typing import TYPE_CHECKING, Optional

-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import MultiWriterIdGenerator
@@ -27,7 +27,12 @@


 class BaseSlavedStore(CacheInvalidationWorkerStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
         if isinstance(self.database_engine, PostgresEngine):
             self._cache_id_gen: Optional[
9 changes: 7 additions & 2 deletions synapse/replication/slave/storage/client_ips.py
@@ -14,7 +14,7 @@

 from typing import TYPE_CHECKING

-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY
 from synapse.util.caches.lrucache import LruCache

@@ -25,7 +25,12 @@


 class SlavedClientIpStore(BaseSlavedStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)

         self.client_ip_last_seen: LruCache[tuple, int] = LruCache(
9 changes: 7 additions & 2 deletions synapse/replication/slave/storage/devices.py
@@ -17,7 +17,7 @@
 from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.devices import DeviceWorkerStore
 from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore
 from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -27,7 +27,12 @@


 class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)

         self.hs = hs
9 changes: 7 additions & 2 deletions synapse/replication/slave/storage/events.py
@@ -15,7 +15,7 @@
 import logging
 from typing import TYPE_CHECKING

-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.event_federation import EventFederationWorkerStore
 from synapse.storage.databases.main.event_push_actions import (
     EventPushActionsWorkerStore,
@@ -58,7 +58,12 @@ class SlavedEventStore(
     RelationsWorkerStore,
     BaseSlavedStore,
 ):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)

         events_max = self._stream_id_gen.get_current_token()
9 changes: 7 additions & 2 deletions synapse/replication/slave/storage/filtering.py
@@ -14,7 +14,7 @@

 from typing import TYPE_CHECKING

-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.filtering import FilteringStore

 from ._base import BaseSlavedStore
@@ -24,7 +24,12 @@


 class SlavedFilteringStore(BaseSlavedStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)

         # Filters are immutable so this cache doesn't need to be expired
9 changes: 7 additions & 2 deletions synapse/replication/slave/storage/groups.py
@@ -17,7 +17,7 @@
 from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.replication.tcp.streams import GroupServerStream
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.group_server import GroupServerWorkerStore
 from synapse.util.caches.stream_change_cache import StreamChangeCache

@@ -26,7 +26,12 @@


 class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)

         self.hs = hs
13 changes: 8 additions & 5 deletions synapse/storage/_base.py
@@ -17,10 +17,8 @@
 from abc import ABCMeta
 from typing import TYPE_CHECKING, Any, Collection, Iterable, Optional, Union

-from synapse.storage.database import LoggingTransaction  # noqa: F401
-from synapse.storage.database import make_in_list_sql_clause  # noqa: F401
-from synapse.storage.database import DatabasePool
-from synapse.storage.types import Connection
+from synapse.storage.database import make_in_list_sql_clause  # noqa: F401; noqa: F401
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.types import get_domain_from_id
 from synapse.util import json_decoder

@@ -38,7 +36,12 @@ class SQLBaseStore(metaclass=ABCMeta):
     per data store (and not one per physical database).
     """

-    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         self.hs = hs
         self._clock = hs.get_clock()
         self.database_engine = database.engine
2 changes: 1 addition & 1 deletion synapse/storage/database.py
@@ -175,7 +175,7 @@ def commit(self) -> None:
     def rollback(self) -> None:
         self.conn.rollback()

-    def __enter__(self) -> "Connection":
+    def __enter__(self) -> "LoggingDatabaseConnection":
         self.conn.__enter__()
         return self
Review comment from @squahtx (Contributor, PR author), Dec 10, 2021, on the changed return type:
The return value here makes its way into the data store constructor:

    with make_conn(database_config, engine, "startup") as db_conn:
        main = main_store_class(database, db_conn, hs)

The constructor requires a LoggingDatabaseConnection because StreamIdGenerator constructors use the txn_name argument of cursor():

    def _load_current_id(
        db_conn: LoggingDatabaseConnection, table: str, column: str, step: int = 1
    ) -> int:
        cur = db_conn.cursor(txn_name="_load_current_id")

    def cursor(
        self, *, txn_name=None, after_callbacks=None, exception_callbacks=None
    ) -> "LoggingTransaction":
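
To make the comment above concrete, here is a minimal, self-contained sketch. FakeLoggingConnection is a stand-in invented for this illustration, not Synapse's real LoggingDatabaseConnection: it only demonstrates that the wrapper's cursor() accepts a txn_name keyword while a bare DB-API connection's cursor() does not, which is why annotating db_conn as a plain Connection would be too loose for callers such as _load_current_id.

    import sqlite3
    from typing import Optional


    class FakeLoggingConnection:
        """Stand-in for the real LoggingDatabaseConnection wrapper."""

        def __init__(self, conn: sqlite3.Connection) -> None:
            self.conn = conn

        def cursor(self, *, txn_name: Optional[str] = None) -> sqlite3.Cursor:
            # The real class returns a LoggingTransaction that records txn_name;
            # here we only show that the keyword exists on the wrapper while the
            # underlying connection's cursor() would reject it.
            return self.conn.cursor()


    raw = sqlite3.connect(":memory:")
    wrapped = FakeLoggingConnection(raw)

    wrapped.cursor(txn_name="_load_current_id")  # accepted by the wrapper
    # raw.cursor(txn_name="_load_current_id")    # TypeError on a plain sqlite3 connection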
9 changes: 7 additions & 2 deletions synapse/storage/databases/main/__init__.py
@@ -18,7 +18,7 @@
 from typing import TYPE_CHECKING, List, Optional, Tuple

 from synapse.config.homeserver import HomeServerConfig
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.stats import UserSortOrder
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import (
@@ -129,7 +129,12 @@ class DataStore(
     LockStore,
     SessionStore,
 ):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         self.hs = hs
         self._clock = hs.get_clock()
         self.database_engine = database.engine
9 changes: 7 additions & 2 deletions synapse/storage/databases/main/account_data.py
@@ -20,7 +20,7 @@
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream
 from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
 from synapse.types import JsonDict
@@ -39,7 +39,12 @@ class AccountDataWorkerStore(SQLBaseStore):
     `get_max_account_data_stream_id` which can be called in the initializer.
     """

-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         self._instance_name = hs.get_instance_name()

         if isinstance(database.engine, PostgresEngine):
10 changes: 7 additions & 3 deletions synapse/storage/databases/main/appservice.py
@@ -24,9 +24,8 @@
 from synapse.config.appservice import load_appservices
 from synapse.events import EventBase
 from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
-from synapse.storage.types import Connection
 from synapse.types import JsonDict
 from synapse.util import json_encoder

@@ -58,7 +57,12 @@ def _make_exclusive_regex(


 class ApplicationServiceWorkerStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         self.services_cache = load_appservices(
             hs.hostname, hs.config.appservice.app_service_config_files
         )
9 changes: 7 additions & 2 deletions synapse/storage/databases/main/cache.py
@@ -25,7 +25,7 @@
     EventsStreamEventRow,
 )
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.engines import PostgresEngine
 from synapse.util.iterutils import batch_iter

@@ -41,7 +41,12 @@


 class CacheInvalidationWorkerStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)

         self._instance_name = hs.get_instance_name()
13 changes: 11 additions & 2 deletions synapse/storage/databases/main/censor_events.py
@@ -18,7 +18,11 @@
 from synapse.events.utils import prune_event_dict
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.util import json_encoder
@@ -31,7 +35,12 @@


 class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)

         if (
22 changes: 18 additions & 4 deletions synapse/storage/databases/main/client_ips.py
@@ -26,7 +26,6 @@
     make_tuple_comparison_clause,
 )
 from synapse.storage.databases.main.monthly_active_users import MonthlyActiveUsersStore
-from synapse.storage.types import Connection
 from synapse.types import JsonDict, UserID
 from synapse.util.caches.lrucache import LruCache

@@ -65,7 +64,12 @@ class LastConnectionInfo(TypedDict):


 class ClientIpBackgroundUpdateStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)

         self.db_pool.updates.register_background_index_update(
@@ -394,7 +398,12 @@ def _devices_last_seen_update_txn(txn: LoggingTransaction) -> int:


 class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
-    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)

         self.user_ips_max_age = hs.config.server.user_ips_max_age
@@ -532,7 +541,12 @@ def get_recent(txn: LoggingTransaction) -> List[Tuple[str, str, str, int]]:


 class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore):
-    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):

         # (user_id, access_token, ip,) -> last_seen
         self.client_ip_last_seen = LruCache[Tuple[str, str, str], int](
7 changes: 6 additions & 1 deletion synapse/storage/databases/main/deviceinbox.py
@@ -601,7 +601,12 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
     REMOVE_HIDDEN_DEVICES = "remove_hidden_devices_from_device_inbox"
     REMOVE_DEAD_DEVICES_FROM_INBOX = "remove_dead_devices_from_device_inbox"

-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)

         self.db_pool.updates.register_background_index_update(