@@ -437,6 +437,7 @@ def _get_unread_counts_by_pos_txn(
437
437
"""
438
438
439
439
counts = NotifCounts ()
440
+ thread_counts = {}
440
441
441
442
# First we pull the counts from the summary table.
442
443
#
@@ -453,7 +454,7 @@ def _get_unread_counts_by_pos_txn(
453
454
# receipt).
454
455
txn .execute (
455
456
"""
456
- SELECT stream_ordering, notif_count, COALESCE(unread_count, 0)
457
+ SELECT stream_ordering, notif_count, COALESCE(unread_count, 0), thread_id
457
458
FROM event_push_summary
458
459
WHERE room_id = ? AND user_id = ?
459
460
AND (
@@ -463,42 +464,70 @@ def _get_unread_counts_by_pos_txn(
463
464
""" ,
464
465
(room_id , user_id , receipt_stream_ordering , receipt_stream_ordering ),
465
466
)
466
- row = txn .fetchone ()
467
+ max_summary_stream_ordering = 0
468
+ for summary_stream_ordering , notif_count , unread_count , thread_id in txn :
469
+ if thread_id == "main" :
470
+ counts = NotifCounts (
471
+ notify_count = notif_count , unread_count = unread_count
472
+ )
473
+ # TODO Delete zeroed out threads completely from the database.
474
+ elif notif_count or unread_count :
475
+ thread_counts [thread_id ] = NotifCounts (
476
+ notify_count = notif_count , unread_count = unread_count
477
+ )
467
478
468
- summary_stream_ordering = 0
469
- if row :
470
- summary_stream_ordering = row [0 ]
471
- counts .notify_count += row [1 ]
472
- counts .unread_count += row [2 ]
479
+ # XXX All threads should have the same stream ordering?
480
+ max_summary_stream_ordering = max (
481
+ summary_stream_ordering , max_summary_stream_ordering
482
+ )
473
483
474
484
# Next we need to count highlights, which aren't summarised
475
485
sql = """
476
- SELECT COUNT(*) FROM event_push_actions
486
+ SELECT COUNT(*), thread_id FROM event_push_actions
477
487
WHERE user_id = ?
478
488
AND room_id = ?
479
489
AND stream_ordering > ?
480
490
AND highlight = 1
491
+ GROUP BY thread_id
481
492
"""
482
493
txn .execute (sql , (user_id , room_id , receipt_stream_ordering ))
483
- row = txn .fetchone ()
484
- if row :
485
- counts .highlight_count += row [0 ]
494
+ for highlight_count , thread_id in txn :
495
+ if thread_id == "main" :
496
+ counts .highlight_count += highlight_count
497
+ elif highlight_count :
498
+ if thread_id in thread_counts :
499
+ thread_counts [thread_id ].highlight_count += highlight_count
500
+ else :
501
+ thread_counts [thread_id ] = NotifCounts (
502
+ notify_count = 0 , unread_count = 0 , highlight_count = highlight_count
503
+ )
486
504
487
505
# Finally we need to count push actions that aren't included in the
488
506
# summary returned above. This might be due to recent events that haven't
489
507
# been summarised yet or the summary is out of date due to a recent read
490
508
# receipt.
491
509
start_unread_stream_ordering = max (
492
- receipt_stream_ordering , summary_stream_ordering
510
+ receipt_stream_ordering , max_summary_stream_ordering
493
511
)
494
- notify_count , unread_count = self ._get_notif_unread_count_for_user_room (
512
+ unread_counts = self ._get_notif_unread_count_for_user_room (
495
513
txn , room_id , user_id , start_unread_stream_ordering
496
514
)
497
515
498
- counts .notify_count += notify_count
499
- counts .unread_count += unread_count
516
+ for notif_count , unread_count , thread_id in unread_counts :
517
+ if thread_id == "main" :
518
+ counts .notify_count += notif_count
519
+ counts .unread_count += unread_count
520
+ elif thread_id in thread_counts :
521
+ thread_counts [thread_id ].notify_count += notif_count
522
+ thread_counts [thread_id ].unread_count += unread_count
523
+ else :
524
+ thread_counts [thread_id ] = NotifCounts (
525
+ notify_count = notif_count ,
526
+ unread_count = unread_count ,
527
+ highlight_count = 0 ,
528
+ )
500
529
501
- return RoomNotifCounts (counts , {} )
530
+ return RoomNotifCounts (counts , thread_counts )
502
531
503
532
def _get_notif_unread_count_for_user_room (
504
533
self ,
@@ -507,7 +536,7 @@ def _get_notif_unread_count_for_user_room(
507
536
user_id : str ,
508
537
stream_ordering : int ,
509
538
max_stream_ordering : Optional [int ] = None ,
510
- ) -> Tuple [int , int ]:
539
+ ) -> List [ Tuple [int , int , str ] ]:
511
540
"""Returns the notify and unread counts from `event_push_actions` for
512
541
the given user/room in the given range.
513
542
@@ -523,13 +552,14 @@ def _get_notif_unread_count_for_user_room(
523
552
If this is not given, then no maximum is applied.
524
553
525
554
Return:
526
- A tuple of the notif count and unread count in the given range.
555
+ A tuple of the notif count and unread count in the given range for
556
+ each thread.
527
557
"""
528
558
529
559
# If there have been no events in the room since the stream ordering,
530
560
# there can't be any push actions either.
531
561
if not self ._events_stream_cache .has_entity_changed (room_id , stream_ordering ):
532
- return 0 , 0
562
+ return []
533
563
534
564
clause = ""
535
565
args = [user_id , room_id , stream_ordering ]
@@ -540,26 +570,23 @@ def _get_notif_unread_count_for_user_room(
540
570
# If the max stream ordering is less than the min stream ordering,
541
571
# then obviously there are zero push actions in that range.
542
572
if max_stream_ordering <= stream_ordering :
543
- return 0 , 0
573
+ return []
544
574
545
575
sql = f"""
546
576
SELECT
547
577
COUNT(CASE WHEN notif = 1 THEN 1 END),
548
- COUNT(CASE WHEN unread = 1 THEN 1 END)
549
- FROM event_push_actions ea
550
- WHERE user_id = ?
578
+ COUNT(CASE WHEN unread = 1 THEN 1 END),
579
+ thread_id
580
+ FROM event_push_actions ea
581
+ WHERE user_id = ?
551
582
AND room_id = ?
552
583
AND ea.stream_ordering > ?
553
584
{ clause }
585
+ GROUP BY thread_id
554
586
"""
555
587
556
588
txn .execute (sql , args )
557
- row = txn .fetchone ()
558
-
559
- if row :
560
- return cast (Tuple [int , int ], row )
561
-
562
- return 0 , 0
589
+ return cast (List [Tuple [int , int , str ]], txn .fetchall ())
563
590
564
591
async def get_push_action_users_in_range (
565
592
self , min_stream_ordering : int , max_stream_ordering : int
@@ -1103,26 +1130,34 @@ def _handle_new_receipts_for_notifs_txn(self, txn: LoggingTransaction) -> bool:
1103
1130
1104
1131
# Fetch the notification counts between the stream ordering of the
1105
1132
# latest receipt and what was previously summarised.
1106
- notif_count , unread_count = self ._get_notif_unread_count_for_user_room (
1133
+ unread_counts = self ._get_notif_unread_count_for_user_room (
1107
1134
txn , room_id , user_id , stream_ordering , old_rotate_stream_ordering
1108
1135
)
1109
1136
1110
- # Replace the previous summary with the new counts.
1111
- #
1112
- # TODO(threads): Upsert per-thread instead of setting them all to main.
1113
- self .db_pool .simple_upsert_txn (
1137
+ # First mark the summary for all threads in the room as cleared.
1138
+ self .db_pool .simple_update_txn (
1114
1139
txn ,
1115
1140
table = "event_push_summary" ,
1116
- keyvalues = {"room_id " : room_id , "user_id " : user_id },
1117
- values = {
1118
- "notif_count" : notif_count ,
1119
- "unread_count" : unread_count ,
1141
+ keyvalues = {"user_id " : user_id , "room_id " : room_id },
1142
+ updatevalues = {
1143
+ "notif_count" : 0 ,
1144
+ "unread_count" : 0 ,
1120
1145
"stream_ordering" : old_rotate_stream_ordering ,
1121
1146
"last_receipt_stream_ordering" : stream_ordering ,
1122
- "thread_id" : "main" ,
1123
1147
},
1124
1148
)
1125
1149
1150
+ # Then any updated threads get their notification count and unread
1151
+ # count updated.
1152
+ self .db_pool .simple_upsert_many_txn (
1153
+ txn ,
1154
+ table = "event_push_summary" ,
1155
+ key_names = ("room_id" , "user_id" , "thread_id" ),
1156
+ key_values = [(room_id , user_id , row [2 ]) for row in unread_counts ],
1157
+ value_names = ("notif_count" , "unread_count" ),
1158
+ value_values = [(row [0 ], row [1 ]) for row in unread_counts ],
1159
+ )
1160
+
1126
1161
# We always update `event_push_summary_last_receipt_stream_id` to
1127
1162
# ensure that we don't rescan the same receipts for remote users.
1128
1163
@@ -1208,23 +1243,23 @@ def _rotate_notifs_before_txn(
1208
1243
1209
1244
# Calculate the new counts that should be upserted into event_push_summary
1210
1245
sql = """
1211
- SELECT user_id, room_id,
1246
+ SELECT user_id, room_id, thread_id,
1212
1247
coalesce(old.%s, 0) + upd.cnt,
1213
1248
upd.stream_ordering
1214
1249
FROM (
1215
- SELECT user_id, room_id, count(*) as cnt,
1250
+ SELECT user_id, room_id, thread_id, count(*) as cnt,
1216
1251
max(ea.stream_ordering) as stream_ordering
1217
1252
FROM event_push_actions AS ea
1218
- LEFT JOIN event_push_summary AS old USING (user_id, room_id)
1253
+ LEFT JOIN event_push_summary AS old USING (user_id, room_id, thread_id )
1219
1254
WHERE ? < ea.stream_ordering AND ea.stream_ordering <= ?
1220
1255
AND (
1221
1256
old.last_receipt_stream_ordering IS NULL
1222
1257
OR old.last_receipt_stream_ordering < ea.stream_ordering
1223
1258
)
1224
1259
AND %s = 1
1225
- GROUP BY user_id, room_id
1260
+ GROUP BY user_id, room_id, thread_id
1226
1261
) AS upd
1227
- LEFT JOIN event_push_summary AS old USING (user_id, room_id)
1262
+ LEFT JOIN event_push_summary AS old USING (user_id, room_id, thread_id )
1228
1263
"""
1229
1264
1230
1265
# First get the count of unread messages.
@@ -1238,11 +1273,11 @@ def _rotate_notifs_before_txn(
1238
1273
# object because we might not have the same amount of rows in each of them. To do
1239
1274
# this, we use a dict indexed on the user ID and room ID to make it easier to
1240
1275
# populate.
1241
- summaries : Dict [Tuple [str , str ], _EventPushSummary ] = {}
1276
+ summaries : Dict [Tuple [str , str , str ], _EventPushSummary ] = {}
1242
1277
for row in txn :
1243
- summaries [(row [0 ], row [1 ])] = _EventPushSummary (
1244
- unread_count = row [2 ],
1245
- stream_ordering = row [3 ],
1278
+ summaries [(row [0 ], row [1 ], row [ 2 ] )] = _EventPushSummary (
1279
+ unread_count = row [3 ],
1280
+ stream_ordering = row [4 ],
1246
1281
notif_count = 0 ,
1247
1282
)
1248
1283
@@ -1253,34 +1288,35 @@ def _rotate_notifs_before_txn(
1253
1288
)
1254
1289
1255
1290
for row in txn :
1256
- if (row [0 ], row [1 ]) in summaries :
1257
- summaries [(row [0 ], row [1 ])].notif_count = row [2 ]
1291
+ if (row [0 ], row [1 ], row [ 2 ] ) in summaries :
1292
+ summaries [(row [0 ], row [1 ], row [ 2 ] )].notif_count = row [3 ]
1258
1293
else :
1259
1294
# Because the rules on notifying are different than the rules on marking
1260
1295
# a message unread, we might end up with messages that notify but aren't
1261
1296
# marked unread, so we might not have a summary for this (user, room)
1262
1297
# tuple to complete.
1263
- summaries [(row [0 ], row [1 ])] = _EventPushSummary (
1298
+ summaries [(row [0 ], row [1 ], row [ 2 ] )] = _EventPushSummary (
1264
1299
unread_count = 0 ,
1265
- stream_ordering = row [3 ],
1266
- notif_count = row [2 ],
1300
+ stream_ordering = row [4 ],
1301
+ notif_count = row [3 ],
1267
1302
)
1268
1303
1269
1304
logger .info ("Rotating notifications, handling %d rows" , len (summaries ))
1270
1305
1271
- # TODO(threads): Update on a per-thread basis.
1272
1306
self .db_pool .simple_upsert_many_txn (
1273
1307
txn ,
1274
1308
table = "event_push_summary" ,
1275
- key_names = ("user_id" , "room_id" ),
1276
- key_values = [(user_id , room_id ) for user_id , room_id in summaries ],
1277
- value_names = ("notif_count" , "unread_count" , "stream_ordering" , "thread_id" ),
1309
+ key_names = ("user_id" , "room_id" , "thread_id" ),
1310
+ key_values = [
1311
+ (user_id , room_id , thread_id )
1312
+ for user_id , room_id , thread_id in summaries
1313
+ ],
1314
+ value_names = ("notif_count" , "unread_count" , "stream_ordering" ),
1278
1315
value_values = [
1279
1316
(
1280
1317
summary .notif_count ,
1281
1318
summary .unread_count ,
1282
1319
summary .stream_ordering ,
1283
- "main" ,
1284
1320
)
1285
1321
for summary in summaries .values ()
1286
1322
],
0 commit comments