Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Update some type hints in storage classes #11652

Merged
merged 3 commits into from
Dec 29, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/11652.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add missing type hints to storage classes.
4 changes: 2 additions & 2 deletions synapse/storage/databases/main/deviceinbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.

import logging
from typing import TYPE_CHECKING, List, Optional, Tuple
from typing import TYPE_CHECKING, List, Optional, Tuple, cast

from synapse.logging import issue9533_logger
from synapse.logging.opentracing import log_kv, set_tag, trace
Expand Down Expand Up @@ -673,7 +673,7 @@ def _remove_dead_devices_from_device_inbox_txn(
# There's a type mismatch here between how we want to type the row and
# what fetchone says it returns, but we silence it because we know that
# res can't be None.
res: Tuple[Optional[int]] = txn.fetchone() # type: ignore[assignment]
res = cast(Tuple[Optional[int]], txn.fetchone())
if res[0] is None:
# this can only happen if the `device_inbox` table is empty, in which
# case we have no work to do.
Expand Down
6 changes: 3 additions & 3 deletions synapse/storage/databases/main/event_federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ def _get_auth_chain_ids_txn(
new_front = set()
for chunk in batch_iter(front, 100):
# Pull the auth events either from the cache or DB.
to_fetch = [] # Event IDs to fetch from DB # type: List[str]
to_fetch: List[str] = [] # Event IDs to fetch from DB
for event_id in chunk:
res = self._event_auth_cache.get(event_id)
if res is None:
Expand Down Expand Up @@ -615,8 +615,8 @@ def _get_auth_chain_difference_txn(
# currently walking, either from cache or DB.
search, chunk = search[:-100], search[-100:]

found = [] # Results found # type: List[Tuple[str, str, int]]
to_fetch = [] # Event IDs to fetch from DB # type: List[str]
found: List[Tuple[str, str, int]] = [] # Results found
to_fetch: List[str] = [] # Event IDs to fetch from DB
for _, event_id in chunk:
res = self._event_auth_cache.get(event_id)
if res is None:
Expand Down
18 changes: 10 additions & 8 deletions synapse/storage/databases/main/event_push_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast

import attr

Expand Down Expand Up @@ -326,7 +326,7 @@ def get_after_receipt(
)
args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
txn.execute(sql, args)
return txn.fetchall() # type: ignore[return-value]
return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall())

after_read_receipt = await self.db_pool.runInteraction(
"get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt
Expand Down Expand Up @@ -357,7 +357,7 @@ def get_no_receipt(
)
args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
txn.execute(sql, args)
return txn.fetchall() # type: ignore[return-value]
return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall())

no_read_receipt = await self.db_pool.runInteraction(
"get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt
Expand Down Expand Up @@ -434,7 +434,7 @@ def get_after_receipt(
)
args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
txn.execute(sql, args)
return txn.fetchall() # type: ignore[return-value]
return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall())

after_read_receipt = await self.db_pool.runInteraction(
"get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt
Expand Down Expand Up @@ -465,7 +465,7 @@ def get_no_receipt(
)
args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
txn.execute(sql, args)
return txn.fetchall() # type: ignore[return-value]
return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall())

no_read_receipt = await self.db_pool.runInteraction(
"get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt
Expand Down Expand Up @@ -662,7 +662,7 @@ def _find_first_stream_ordering_after_ts_txn(
The stream ordering
"""
txn.execute("SELECT MAX(stream_ordering) FROM events")
max_stream_ordering = txn.fetchone()[0] # type: ignore[index]
max_stream_ordering = cast(Tuple[Optional[int]], txn.fetchone())[0]

if max_stream_ordering is None:
return 0
Expand Down Expand Up @@ -731,7 +731,7 @@ def f(txn: LoggingTransaction) -> Optional[Tuple[int]]:
" LIMIT 1"
)
txn.execute(sql, (stream_ordering,))
return txn.fetchone() # type: ignore[return-value]
return cast(Optional[Tuple[int]], txn.fetchone())

result = await self.db_pool.runInteraction(
"get_time_of_last_push_action_before", f
Expand Down Expand Up @@ -1029,7 +1029,9 @@ def f(
" LIMIT ?" % (before_clause,)
)
txn.execute(sql, args)
return txn.fetchall() # type: ignore[return-value]
return cast(
List[Tuple[str, str, int, int, str, bool, str, int]], txn.fetchall()
)

push_actions = await self.db_pool.runInteraction("get_push_actions_for_user", f)
return [
Expand Down
4 changes: 2 additions & 2 deletions synapse/storage/databases/main/filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Union
from typing import Optional, Tuple, Union, cast

from canonicaljson import encode_canonical_json

Expand Down Expand Up @@ -63,7 +63,7 @@ def _do_txn(txn: LoggingTransaction) -> int:

sql = "SELECT MAX(filter_id) FROM user_filters WHERE user_id = ?"
txn.execute(sql, (user_localpart,))
max_id = txn.fetchone()[0] # type: ignore[index]
max_id = cast(Tuple[Optional[int]], txn.fetchone())[0]
if max_id is None:
filter_id = 0
else:
Expand Down
3 changes: 2 additions & 1 deletion synapse/storage/databases/main/media_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
Optional,
Tuple,
Union,
cast,
)

from synapse.storage._base import SQLBaseStore
Expand Down Expand Up @@ -220,7 +221,7 @@ def get_local_media_by_user_paginate_txn(
WHERE user_id = ?
"""
txn.execute(sql, args)
count = txn.fetchone()[0] # type: ignore[index]
count = cast(Tuple[int], txn.fetchone())[0]

sql = """
SELECT
Expand Down
6 changes: 3 additions & 3 deletions synapse/storage/databases/main/pusher.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,7 @@ async def add_pusher(
# invalidate, since we the user might not have had a pusher before
await self.db_pool.runInteraction(
"add_pusher",
self._invalidate_cache_and_stream, # type: ignore
self._invalidate_cache_and_stream, # type: ignore[attr-defined]
self.get_if_user_has_pusher,
(user_id,),
)
Expand All @@ -503,7 +503,7 @@ async def delete_pusher_by_app_id_pushkey_user_id(
self, app_id: str, pushkey: str, user_id: str
) -> None:
def delete_pusher_txn(txn, stream_id):
self._invalidate_cache_and_stream( # type: ignore
self._invalidate_cache_and_stream( # type: ignore[attr-defined]
txn, self.get_if_user_has_pusher, (user_id,)
)

Expand Down Expand Up @@ -548,7 +548,7 @@ async def delete_all_pushers_for_user(self, user_id: str) -> None:
pushers = list(await self.get_pushers_by_user_id(user_id))

def delete_pushers_txn(txn, stream_ids):
self._invalidate_cache_and_stream( # type: ignore
self._invalidate_cache_and_stream( # type: ignore[attr-defined]
txn, self.get_if_user_has_pusher, (user_id,)
)

Expand Down
17 changes: 10 additions & 7 deletions synapse/storage/databases/main/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import logging
import random
import re
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast

import attr

Expand Down Expand Up @@ -1357,12 +1357,15 @@ def _use_registration_token_txn(txn):
# Override type because the return type is only optional if
# allow_none is True, and we don't want mypy throwing errors
# about None not being indexable.
res: Dict[str, Any] = self.db_pool.simple_select_one_txn(
txn,
"registration_tokens",
keyvalues={"token": token},
retcols=["pending", "completed"],
) # type: ignore
res = cast(
Dict[str, Any],
self.db_pool.simple_select_one_txn(
txn,
"registration_tokens",
keyvalues={"token": token},
retcols=["pending", "completed"],
),
)
Comment on lines +1360 to +1368
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had to remind myself why we needed a cast here -- it is because simple_select_one_txn is a staticmethod and overrides don't work for it.

It seems the other methods that do this are class methods instead....


# Decrement pending and increment completed
self.db_pool.simple_update_one_txn(
Expand Down
4 changes: 2 additions & 2 deletions synapse/storage/databases/main/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

import logging
from typing import List, Optional, Tuple, Union
from typing import List, Optional, Tuple, Union, cast

import attr

Expand Down Expand Up @@ -399,7 +399,7 @@ def _get_thread_summary_txn(
AND relation_type = ?
"""
txn.execute(sql, (event_id, room_id, RelationTypes.THREAD))
count = txn.fetchone()[0] # type: ignore[index]
count = cast(Tuple[int], txn.fetchone())[0]

return count, latest_event_id

Expand Down
15 changes: 9 additions & 6 deletions synapse/storage/databases/main/ui_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union, cast

import attr

Expand Down Expand Up @@ -225,11 +225,14 @@ def _set_ui_auth_session_data_txn(
self, txn: LoggingTransaction, session_id: str, key: str, value: Any
):
# Get the current value.
result: Dict[str, Any] = self.db_pool.simple_select_one_txn( # type: ignore
txn,
table="ui_auth_sessions",
keyvalues={"session_id": session_id},
retcols=("serverdict",),
result = cast(
Dict[str, Any],
self.db_pool.simple_select_one_txn(
txn,
table="ui_auth_sessions",
keyvalues={"session_id": session_id},
retcols=("serverdict",),
),
)

# Update it and add it back to the database.
Expand Down