@@ -13,7 +13,7 @@
 # limitations under the License.

 import logging
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, Tuple, cast

 import attr

@@ -240,12 +240,14 @@ def __init__(

         ################################################################################

-    async def _background_reindex_fields_sender(self, progress, batch_size):
+    async def _background_reindex_fields_sender(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         target_min_stream_id = progress["target_min_stream_id_inclusive"]
         max_stream_id = progress["max_stream_id_exclusive"]
         rows_inserted = progress.get("rows_inserted", 0)

-        def reindex_txn(txn):
+        def reindex_txn(txn: LoggingTransaction) -> int:
             sql = (
                 "SELECT stream_ordering, event_id, json FROM events"
                 " INNER JOIN event_json USING (event_id)"
@@ -307,12 +309,14 @@ def reindex_txn(txn):

         return result

-    async def _background_reindex_origin_server_ts(self, progress, batch_size):
+    async def _background_reindex_origin_server_ts(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         target_min_stream_id = progress["target_min_stream_id_inclusive"]
         max_stream_id = progress["max_stream_id_exclusive"]
         rows_inserted = progress.get("rows_inserted", 0)

-        def reindex_search_txn(txn):
+        def reindex_search_txn(txn: LoggingTransaction) -> int:
             sql = (
                 "SELECT stream_ordering, event_id FROM events"
                 " WHERE ? <= stream_ordering AND stream_ordering < ?"
@@ -381,7 +385,9 @@ def reindex_search_txn(txn):

         return result

-    async def _cleanup_extremities_bg_update(self, progress, batch_size):
+    async def _cleanup_extremities_bg_update(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         """Background update to clean out extremities that should have been
         deleted previously.

@@ -402,12 +408,12 @@ async def _cleanup_extremities_bg_update(self, progress, batch_size):
         # have any descendants, but if they do then we should delete those
         # extremities.

-        def _cleanup_extremities_bg_update_txn(txn):
+        def _cleanup_extremities_bg_update_txn(txn: LoggingTransaction) -> int:
             # The set of extremity event IDs that we're checking this round
             original_set = set()

-            # A dict[str, set[str]] of event ID to their prev events.
-            graph = {}
+            # A dict[str, Set[str]] of event ID to their prev events.
+            graph: Dict[str, Set[str]] = {}

             # The set of descendants of the original set that are not rejected
             # nor soft-failed. Ancestors of these events should be removed
@@ -536,7 +542,7 @@ def _cleanup_extremities_bg_update_txn(txn):
             room_ids = {row["room_id"] for row in rows}
             for room_id in room_ids:
                 txn.call_after(
-                    self.get_latest_event_ids_in_room.invalidate, (room_id,)
+                    self.get_latest_event_ids_in_room.invalidate, (room_id,)  # type: ignore[attr-defined]
                 )

             self.db_pool.simple_delete_many_txn(
@@ -558,7 +564,7 @@ def _cleanup_extremities_bg_update_txn(txn):
                 _BackgroundUpdates.DELETE_SOFT_FAILED_EXTREMITIES
             )

-            def _drop_table_txn(txn):
+            def _drop_table_txn(txn: LoggingTransaction) -> None:
                 txn.execute("DROP TABLE _extremities_to_check")

             await self.db_pool.runInteraction(
@@ -567,11 +573,11 @@ def _drop_table_txn(txn):

         return num_handled

-    async def _redactions_received_ts(self, progress, batch_size):
+    async def _redactions_received_ts(self, progress: JsonDict, batch_size: int) -> int:
         """Handles filling out the `received_ts` column in redactions."""
         last_event_id = progress.get("last_event_id", "")

-        def _redactions_received_ts_txn(txn):
+        def _redactions_received_ts_txn(txn: LoggingTransaction) -> int:
             # Fetch the set of event IDs that we want to update
             sql = """
                 SELECT event_id FROM redactions
@@ -622,10 +628,12 @@ def _redactions_received_ts_txn(txn):

         return count

-    async def _event_fix_redactions_bytes(self, progress, batch_size):
+    async def _event_fix_redactions_bytes(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         """Undoes hex encoded censored redacted event JSON."""

-        def _event_fix_redactions_bytes_txn(txn):
+        def _event_fix_redactions_bytes_txn(txn: LoggingTransaction) -> None:
             # This update is quite fast due to new index.
             txn.execute(
                 """
@@ -650,11 +658,11 @@ def _event_fix_redactions_bytes_txn(txn):

         return 1

-    async def _event_store_labels(self, progress, batch_size):
+    async def _event_store_labels(self, progress: JsonDict, batch_size: int) -> int:
         """Background update handler which will store labels for existing events."""
         last_event_id = progress.get("last_event_id", "")

-        def _event_store_labels_txn(txn):
+        def _event_store_labels_txn(txn: LoggingTransaction) -> int:
             txn.execute(
                 """
                 SELECT event_id, json FROM event_json
@@ -754,7 +762,10 @@ def get_rejected_events(
                 ),
             )

-            return [(row[0], row[1], db_to_json(row[2]), row[3], row[4]) for row in txn]  # type: ignore
+            return cast(
+                List[Tuple[str, str, JsonDict, bool, bool]],
+                [(row[0], row[1], db_to_json(row[2]), row[3], row[4]) for row in txn],
+            )

         results = await self.db_pool.runInteraction(
             desc="_rejected_events_metadata_get", func=get_rejected_events
@@ -912,7 +923,7 @@ async def _chain_cover_index(self, progress: dict, batch_size: int) -> int:

     def _calculate_chain_cover_txn(
         self,
-        txn: Cursor,
+        txn: LoggingTransaction,
         last_room_id: str,
         last_depth: int,
         last_stream: int,
@@ -1023,10 +1034,10 @@ def _calculate_chain_cover_txn(
         PersistEventsStore._add_chain_cover_index(
             txn,
             self.db_pool,
-            self.event_chain_id_gen,
+            self.event_chain_id_gen,  # type: ignore[attr-defined]
             event_to_room_id,
             event_to_types,
-            event_to_auth_chain,
+            cast(Dict[str, Sequence[str]], event_to_auth_chain),
         )

         return _CalculateChainCover(
@@ -1046,7 +1057,7 @@ async def _purged_chain_cover_index(self, progress: dict, batch_size: int) -> int:
         """
         current_event_id = progress.get("current_event_id", "")

-        def purged_chain_cover_txn(txn) -> int:
+        def purged_chain_cover_txn(txn: LoggingTransaction) -> int:
             # The event ID from events will be null if the chain ID / sequence
             # number points to a purged event.
             sql = """
@@ -1181,14 +1192,14 @@ def _event_arbitrary_relations_txn(txn: LoggingTransaction) -> int:
             # Iterate the parent IDs and invalidate caches.
             for parent_id in {r[1] for r in relations_to_insert}:
                 cache_tuple = (parent_id,)
-                self._invalidate_cache_and_stream(
-                    txn, self.get_relations_for_event, cache_tuple
+                self._invalidate_cache_and_stream(  # type: ignore[attr-defined]
+                    txn, self.get_relations_for_event, cache_tuple  # type: ignore[attr-defined]
                 )
-                self._invalidate_cache_and_stream(
-                    txn, self.get_aggregation_groups_for_event, cache_tuple
+                self._invalidate_cache_and_stream(  # type: ignore[attr-defined]
+                    txn, self.get_aggregation_groups_for_event, cache_tuple  # type: ignore[attr-defined]
                 )
-                self._invalidate_cache_and_stream(
-                    txn, self.get_thread_summary, cache_tuple
+                self._invalidate_cache_and_stream(  # type: ignore[attr-defined]
+                    txn, self.get_thread_summary, cache_tuple  # type: ignore[attr-defined]
                 )

             if results:
@@ -1220,7 +1231,7 @@ async def _background_populate_stream_ordering2(
         """
         batch_size = max(batch_size, 1)

-        def process(txn: Cursor) -> int:
+        def process(txn: LoggingTransaction) -> int:
             last_stream = progress.get("last_stream", -(1 << 31))
             txn.execute(
                 """