This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 07a3b5d

Add type hints to synapse/storage/databases/main/events_bg_updates.py (#11654)
1 parent 2c7f5e7 commit 07a3b5d

File tree: 3 files changed, +44 -30 lines

changelog.d/11654.misc (+1)

@@ -0,0 +1 @@
+Add missing type hints to storage classes.

mypy.ini (+3 -1)

@@ -28,7 +28,6 @@ exclude = (?x)
    |synapse/storage/databases/main/cache.py
    |synapse/storage/databases/main/devices.py
    |synapse/storage/databases/main/event_federation.py
-   |synapse/storage/databases/main/events_bg_updates.py
    |synapse/storage/databases/main/group_server.py
    |synapse/storage/databases/main/metrics.py
    |synapse/storage/databases/main/monthly_active_users.py
@@ -200,6 +199,9 @@ disallow_untyped_defs = True
 [mypy-synapse.storage.databases.main.event_push_actions]
 disallow_untyped_defs = True
 
+[mypy-synapse.storage.databases.main.events_bg_updates]
+disallow_untyped_defs = True
+
 [mypy-synapse.storage.databases.main.events_worker]
 disallow_untyped_defs = True
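For context: removing the file from mypy's `exclude` pattern brings it under checking at all, and the new per-module `disallow_untyped_defs = True` section makes mypy reject any function in the module that lacks annotations. A minimal sketch of the effect (`handler` and `handler_typed` are hypothetical names, not from this commit):

    # Under disallow_untyped_defs = True, mypy reports:
    #   error: Function is missing a type annotation
    def handler(progress, batch_size):
        return batch_size

    # A fully annotated definition satisfies the check:
    def handler_typed(progress: dict, batch_size: int) -> int:
        return batch_size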

synapse/storage/databases/main/events_bg_updates.py (+40 -29)

@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, Tuple, cast
 
 import attr

@@ -240,12 +240,14 @@ def __init__(
 
     ################################################################################
 
-    async def _background_reindex_fields_sender(self, progress, batch_size):
+    async def _background_reindex_fields_sender(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         target_min_stream_id = progress["target_min_stream_id_inclusive"]
         max_stream_id = progress["max_stream_id_exclusive"]
         rows_inserted = progress.get("rows_inserted", 0)
 
-        def reindex_txn(txn):
+        def reindex_txn(txn: LoggingTransaction) -> int:
             sql = (
                 "SELECT stream_ordering, event_id, json FROM events"
                 " INNER JOIN event_json USING (event_id)"
@@ -307,12 +309,14 @@ def reindex_txn(txn):
 
         return result
 
-    async def _background_reindex_origin_server_ts(self, progress, batch_size):
+    async def _background_reindex_origin_server_ts(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         target_min_stream_id = progress["target_min_stream_id_inclusive"]
         max_stream_id = progress["max_stream_id_exclusive"]
         rows_inserted = progress.get("rows_inserted", 0)
 
-        def reindex_search_txn(txn):
+        def reindex_search_txn(txn: LoggingTransaction) -> int:
             sql = (
                 "SELECT stream_ordering, event_id FROM events"
                 " WHERE ? <= stream_ordering AND stream_ordering < ?"
@@ -381,7 +385,9 @@ def reindex_search_txn(txn):
 
         return result
 
-    async def _cleanup_extremities_bg_update(self, progress, batch_size):
+    async def _cleanup_extremities_bg_update(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         """Background update to clean out extremities that should have been
         deleted previously.
@@ -402,12 +408,12 @@ async def _cleanup_extremities_bg_update(self, progress, batch_size):
         # have any descendants, but if they do then we should delete those
         # extremities.
 
-        def _cleanup_extremities_bg_update_txn(txn):
+        def _cleanup_extremities_bg_update_txn(txn: LoggingTransaction) -> int:
             # The set of extremity event IDs that we're checking this round
             original_set = set()
 
-            # A dict[str, set[str]] of event ID to their prev events.
-            graph = {}
+            # A dict[str, Set[str]] of event ID to their prev events.
+            graph: Dict[str, Set[str]] = {}
 
             # The set of descendants of the original set that are not rejected
             # nor soft-failed. Ancestors of these events should be removed
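The annotation on `graph` is needed because an empty literal gives mypy nothing to infer an element type from; under strict checking an unannotated `{}` produces a "Need type annotation" error. A small illustration:

    from typing import Dict, Set

    # Without the annotation mypy reports:
    #   error: Need type annotation for "graph"
    graph: Dict[str, Set[str]] = {}

    graph["$event_a"] = {"$event_b", "$event_c"}  # event ID -> prev events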
@@ -536,7 +542,7 @@ def _cleanup_extremities_bg_update_txn(txn):
                 room_ids = {row["room_id"] for row in rows}
                 for room_id in room_ids:
                     txn.call_after(
-                        self.get_latest_event_ids_in_room.invalidate, (room_id,)
+                        self.get_latest_event_ids_in_room.invalidate, (room_id,)  # type: ignore[attr-defined]
                     )
 
                 self.db_pool.simple_delete_many_txn(
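The `type: ignore[attr-defined]` comments are needed because Synapse's store classes are mixins that only gain attributes like this cache when combined into the full data store; mypy checks each class in isolation and cannot see attributes defined on siblings. A stripped-down sketch of the situation (class and method names are illustrative, not from this commit):

    class CacheMixin:
        def get_latest_ids(self) -> None: ...

    class UpdatesMixin:
        def work(self) -> None:
            # mypy: "UpdatesMixin" has no attribute "get_latest_ids" [attr-defined]
            self.get_latest_ids()  # type: ignore[attr-defined]

    class DataStore(UpdatesMixin, CacheMixin):  # attribute exists at runtime
        pass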
@@ -558,7 +564,7 @@ def _cleanup_extremities_bg_update_txn(txn):
                 _BackgroundUpdates.DELETE_SOFT_FAILED_EXTREMITIES
             )
 
-            def _drop_table_txn(txn):
+            def _drop_table_txn(txn: LoggingTransaction) -> None:
                 txn.execute("DROP TABLE _extremities_to_check")
 
             await self.db_pool.runInteraction(
@@ -567,11 +573,11 @@ def _drop_table_txn(txn):
 
         return num_handled
 
-    async def _redactions_received_ts(self, progress, batch_size):
+    async def _redactions_received_ts(self, progress: JsonDict, batch_size: int) -> int:
         """Handles filling out the `received_ts` column in redactions."""
         last_event_id = progress.get("last_event_id", "")
 
-        def _redactions_received_ts_txn(txn):
+        def _redactions_received_ts_txn(txn: LoggingTransaction) -> int:
             # Fetch the set of event IDs that we want to update
             sql = """
                 SELECT event_id FROM redactions
@@ -622,10 +628,12 @@ def _redactions_received_ts_txn(txn):
 
         return count
 
-    async def _event_fix_redactions_bytes(self, progress, batch_size):
+    async def _event_fix_redactions_bytes(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         """Undoes hex encoded censored redacted event JSON."""
 
-        def _event_fix_redactions_bytes_txn(txn):
+        def _event_fix_redactions_bytes_txn(txn: LoggingTransaction) -> None:
             # This update is quite fast due to new index.
             txn.execute(
                 """
@@ -650,11 +658,11 @@ def _event_fix_redactions_bytes_txn(txn):
 
         return 1
 
-    async def _event_store_labels(self, progress, batch_size):
+    async def _event_store_labels(self, progress: JsonDict, batch_size: int) -> int:
         """Background update handler which will store labels for existing events."""
         last_event_id = progress.get("last_event_id", "")
 
-        def _event_store_labels_txn(txn):
+        def _event_store_labels_txn(txn: LoggingTransaction) -> int:
             txn.execute(
                 """
                 SELECT event_id, json FROM event_json
@@ -754,7 +762,10 @@ def get_rejected_events(
                 ),
             )
 
-            return [(row[0], row[1], db_to_json(row[2]), row[3], row[4]) for row in txn]  # type: ignore
+            return cast(
+                List[Tuple[str, str, JsonDict, bool, bool]],
+                [(row[0], row[1], db_to_json(row[2]), row[3], row[4]) for row in txn],
+            )
 
         results = await self.db_pool.runInteraction(
             desc="_rejected_events_metadata_get", func=get_rejected_events
@@ -912,7 +923,7 @@ async def _chain_cover_index(self, progress: dict, batch_size: int) -> int:
 
     def _calculate_chain_cover_txn(
         self,
-        txn: Cursor,
+        txn: LoggingTransaction,
         last_room_id: str,
         last_depth: int,
         last_stream: int,
@@ -1023,10 +1034,10 @@ def _calculate_chain_cover_txn(
         PersistEventsStore._add_chain_cover_index(
             txn,
             self.db_pool,
-            self.event_chain_id_gen,
+            self.event_chain_id_gen,  # type: ignore[attr-defined]
             event_to_room_id,
             event_to_types,
-            event_to_auth_chain,
+            cast(Dict[str, Sequence[str]], event_to_auth_chain),
         )
 
         return _CalculateChainCover(
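The `cast(Dict[str, Sequence[str]], ...)` here works around invariance: `Dict` is invariant in its value type, so a `Dict[str, List[str]]` is not accepted where `Dict[str, Sequence[str]]` is expected, even when the callee only reads the mapping. A small illustration:

    from typing import Dict, List, Sequence, cast

    chains: Dict[str, List[str]] = {"$event": ["$auth_a", "$auth_b"]}

    def total(m: Dict[str, Sequence[str]]) -> int:
        return sum(len(v) for v in m.values())

    # Passing `chains` directly is rejected despite List being a Sequence,
    # because Dict values are invariant; the cast bridges the gap for a
    # read-only callee.
    print(total(cast(Dict[str, Sequence[str]], chains)))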
@@ -1046,7 +1057,7 @@ async def _purged_chain_cover_index(self, progress: dict, batch_size: int) -> int:
         """
         current_event_id = progress.get("current_event_id", "")
 
-        def purged_chain_cover_txn(txn) -> int:
+        def purged_chain_cover_txn(txn: LoggingTransaction) -> int:
             # The event ID from events will be null if the chain ID / sequence
             # number points to a purged event.
             sql = """
@@ -1181,14 +1192,14 @@ def _event_arbitrary_relations_txn(txn: LoggingTransaction) -> int:
             # Iterate the parent IDs and invalidate caches.
             for parent_id in {r[1] for r in relations_to_insert}:
                 cache_tuple = (parent_id,)
-                self._invalidate_cache_and_stream(
-                    txn, self.get_relations_for_event, cache_tuple
+                self._invalidate_cache_and_stream(  # type: ignore[attr-defined]
+                    txn, self.get_relations_for_event, cache_tuple  # type: ignore[attr-defined]
                 )
-                self._invalidate_cache_and_stream(
-                    txn, self.get_aggregation_groups_for_event, cache_tuple
+                self._invalidate_cache_and_stream(  # type: ignore[attr-defined]
+                    txn, self.get_aggregation_groups_for_event, cache_tuple  # type: ignore[attr-defined]
                 )
-                self._invalidate_cache_and_stream(
-                    txn, self.get_thread_summary, cache_tuple
+                self._invalidate_cache_and_stream(  # type: ignore[attr-defined]
+                    txn, self.get_thread_summary, cache_tuple  # type: ignore[attr-defined]
                 )
 
         if results:
@@ -1220,7 +1231,7 @@ async def _background_populate_stream_ordering2(
         """
         batch_size = max(batch_size, 1)
 
-        def process(txn: Cursor) -> int:
+        def process(txn: LoggingTransaction) -> int:
             last_stream = progress.get("last_stream", -(1 << 31))
             txn.execute(
                 """
