Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 0147b3d

Browse files
authored
Add missing type hints to synapse.logging.context (#11556)
1 parent 2519bea commit 0147b3d

File tree

13 files changed

+215
-122
lines changed

13 files changed

+215
-122
lines changed

changelog.d/11556.misc

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add missing type hints to `synapse.logging.context`.

mypy.ini

+3
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,9 @@ disallow_untyped_defs = True
167167
[mypy-synapse.http.server]
168168
disallow_untyped_defs = True
169169

170+
[mypy-synapse.logging.context]
171+
disallow_untyped_defs = True
172+
170173
[mypy-synapse.metrics.*]
171174
disallow_untyped_defs = True
172175

stubs/txredisapi.pyi

+5-4
Original file line numberDiff line numberDiff line change
@@ -17,20 +17,21 @@
1717
from typing import Any, List, Optional, Type, Union
1818

1919
from twisted.internet import protocol
20+
from twisted.internet.defer import Deferred
2021

2122
class RedisProtocol(protocol.Protocol):
2223
def publish(self, channel: str, message: bytes): ...
23-
async def ping(self) -> None: ...
24-
async def set(
24+
def ping(self) -> "Deferred[None]": ...
25+
def set(
2526
self,
2627
key: str,
2728
value: Any,
2829
expire: Optional[int] = None,
2930
pexpire: Optional[int] = None,
3031
only_if_not_exists: bool = False,
3132
only_if_exists: bool = False,
32-
) -> None: ...
33-
async def get(self, key: str) -> Any: ...
33+
) -> "Deferred[None]": ...
34+
def get(self, key: str) -> "Deferred[Any]": ...
3435

3536
class SubscriberProtocol(RedisProtocol):
3637
def __init__(self, *args, **kwargs): ...

synapse/federation/federation_server.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030

3131
from prometheus_client import Counter, Gauge, Histogram
3232

33-
from twisted.internet import defer
3433
from twisted.internet.abstract import isIPAddress
3534
from twisted.python import failure
3635

@@ -67,7 +66,7 @@
6766
from synapse.storage.databases.main.lock import Lock
6867
from synapse.types import JsonDict, get_domain_from_id
6968
from synapse.util import glob_to_regex, json_decoder, unwrapFirstError
70-
from synapse.util.async_helpers import Linearizer, concurrently_execute
69+
from synapse.util.async_helpers import Linearizer, concurrently_execute, gather_results
7170
from synapse.util.caches.response_cache import ResponseCache
7271
from synapse.util.stringutils import parse_server_name
7372

@@ -360,13 +359,13 @@ async def _handle_incoming_transaction(
360359
# want to block things like to device messages from reaching clients
361360
# behind the potentially expensive handling of PDUs.
362361
pdu_results, _ = await make_deferred_yieldable(
363-
defer.gatherResults(
364-
[
362+
gather_results(
363+
(
365364
run_in_background(
366365
self._handle_pdus_in_txn, origin, transaction, request_time
367366
),
368367
run_in_background(self._handle_edus_in_txn, origin, transaction),
369-
],
368+
),
370369
consumeErrors=True,
371370
).addErrback(unwrapFirstError)
372371
)

synapse/handlers/federation.py

+11-8
Original file line numberDiff line numberDiff line change
@@ -360,31 +360,34 @@ async def try_backfill(domains: List[str]) -> bool:
360360

361361
logger.debug("calling resolve_state_groups in _maybe_backfill")
362362
resolve = preserve_fn(self.state_handler.resolve_state_groups_for_events)
363-
states = await make_deferred_yieldable(
363+
states_list = await make_deferred_yieldable(
364364
defer.gatherResults(
365365
[resolve(room_id, [e]) for e in event_ids], consumeErrors=True
366366
)
367367
)
368368

369-
# dict[str, dict[tuple, str]], a map from event_id to state map of
370-
# event_ids.
371-
states = dict(zip(event_ids, [s.state for s in states]))
369+
# A map from event_id to state map of event_ids.
370+
state_ids: Dict[str, StateMap[str]] = dict(
371+
zip(event_ids, [s.state for s in states_list])
372+
)
372373

373374
state_map = await self.store.get_events(
374-
[e_id for ids in states.values() for e_id in ids.values()],
375+
[e_id for ids in state_ids.values() for e_id in ids.values()],
375376
get_prev_content=False,
376377
)
377-
states = {
378+
379+
# A map from event_id to state map of events.
380+
state_events: Dict[str, StateMap[EventBase]] = {
378381
key: {
379382
k: state_map[e_id]
380383
for k, e_id in state_dict.items()
381384
if e_id in state_map
382385
}
383-
for key, state_dict in states.items()
386+
for key, state_dict in state_ids.items()
384387
}
385388

386389
for e_id in event_ids:
387-
likely_extremeties_domains = get_domains_from_state(states[e_id])
390+
likely_extremeties_domains = get_domains_from_state(state_events[e_id])
388391

389392
success = await try_backfill(
390393
[

synapse/handlers/initial_sync.py

+19-14
Original file line numberDiff line numberDiff line change
@@ -13,21 +13,27 @@
1313
# limitations under the License.
1414

1515
import logging
16-
from typing import TYPE_CHECKING, List, Optional, Tuple
17-
18-
from twisted.internet import defer
16+
from typing import TYPE_CHECKING, List, Optional, Tuple, cast
1917

2018
from synapse.api.constants import EduTypes, EventTypes, Membership
2119
from synapse.api.errors import SynapseError
20+
from synapse.events import EventBase
2221
from synapse.events.validator import EventValidator
2322
from synapse.handlers.presence import format_user_presence_state
2423
from synapse.handlers.receipts import ReceiptEventSource
2524
from synapse.logging.context import make_deferred_yieldable, run_in_background
2625
from synapse.storage.roommember import RoomsForUser
2726
from synapse.streams.config import PaginationConfig
28-
from synapse.types import JsonDict, Requester, RoomStreamToken, StreamToken, UserID
27+
from synapse.types import (
28+
JsonDict,
29+
Requester,
30+
RoomStreamToken,
31+
StateMap,
32+
StreamToken,
33+
UserID,
34+
)
2935
from synapse.util import unwrapFirstError
30-
from synapse.util.async_helpers import concurrently_execute
36+
from synapse.util.async_helpers import concurrently_execute, gather_results
3137
from synapse.util.caches.response_cache import ResponseCache
3238
from synapse.visibility import filter_events_for_client
3339

@@ -190,22 +196,21 @@ async def handle_room(event: RoomsForUser) -> None:
190196
)
191197
deferred_room_state = run_in_background(
192198
self.state_store.get_state_for_events, [event.event_id]
193-
)
194-
deferred_room_state.addCallback(
195-
lambda states: states[event.event_id]
199+
).addCallback(
200+
lambda states: cast(StateMap[EventBase], states[event.event_id])
196201
)
197202

198203
(messages, token), current_state = await make_deferred_yieldable(
199-
defer.gatherResults(
200-
[
204+
gather_results(
205+
(
201206
run_in_background(
202207
self.store.get_recent_events_for_room,
203208
event.room_id,
204209
limit=limit,
205210
end_token=room_end_token,
206211
),
207212
deferred_room_state,
208-
]
213+
)
209214
)
210215
).addErrback(unwrapFirstError)
211216

@@ -454,8 +459,8 @@ async def get_receipts() -> List[JsonDict]:
454459
return receipts
455460

456461
presence, receipts, (messages, token) = await make_deferred_yieldable(
457-
defer.gatherResults(
458-
[
462+
gather_results(
463+
(
459464
run_in_background(get_presence),
460465
run_in_background(get_receipts),
461466
run_in_background(
@@ -464,7 +469,7 @@ async def get_receipts() -> List[JsonDict]:
464469
limit=limit,
465470
end_token=now_token.room_key,
466471
),
467-
],
472+
),
468473
consumeErrors=True,
469474
).addErrback(unwrapFirstError)
470475
)

synapse/handlers/message.py

+6-7
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121

2222
from canonicaljson import encode_canonical_json
2323

24-
from twisted.internet import defer
2524
from twisted.internet.interfaces import IDelayedCall
2625

2726
from synapse import event_auth
@@ -57,7 +56,7 @@
5756
from synapse.storage.state import StateFilter
5857
from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester
5958
from synapse.util import json_decoder, json_encoder, log_failure
60-
from synapse.util.async_helpers import Linearizer, unwrapFirstError
59+
from synapse.util.async_helpers import Linearizer, gather_results, unwrapFirstError
6160
from synapse.util.caches.expiringcache import ExpiringCache
6261
from synapse.util.metrics import measure_func
6362
from synapse.visibility import filter_events_for_client
@@ -1168,9 +1167,9 @@ async def handle_new_client_event(
11681167

11691168
# We now persist the event (and update the cache in parallel, since we
11701169
# don't want to block on it).
1171-
result = await make_deferred_yieldable(
1172-
defer.gatherResults(
1173-
[
1170+
result, _ = await make_deferred_yieldable(
1171+
gather_results(
1172+
(
11741173
run_in_background(
11751174
self._persist_event,
11761175
requester=requester,
@@ -1182,12 +1181,12 @@ async def handle_new_client_event(
11821181
run_in_background(
11831182
self.cache_joined_hosts_for_event, event, context
11841183
).addErrback(log_failure, "cache_joined_hosts_for_event failed"),
1185-
],
1184+
),
11861185
consumeErrors=True,
11871186
)
11881187
).addErrback(unwrapFirstError)
11891188

1190-
return result[0]
1189+
return result
11911190

11921191
async def _persist_event(
11931192
self,

synapse/http/federation/matrix_federation_agent.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from twisted.internet import defer
2626
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
2727
from twisted.internet.interfaces import (
28+
IProtocol,
2829
IProtocolFactory,
2930
IReactorCore,
3031
IStreamClientEndpoint,
@@ -309,12 +310,14 @@ def __init__(
309310

310311
self._srv_resolver = srv_resolver
311312

312-
def connect(self, protocol_factory: IProtocolFactory) -> defer.Deferred:
313+
def connect(
314+
self, protocol_factory: IProtocolFactory
315+
) -> "defer.Deferred[IProtocol]":
313316
"""Implements IStreamClientEndpoint interface"""
314317

315318
return run_in_background(self._do_connect, protocol_factory)
316319

317-
async def _do_connect(self, protocol_factory: IProtocolFactory) -> None:
320+
async def _do_connect(self, protocol_factory: IProtocolFactory) -> IProtocol:
318321
first_exception = None
319322

320323
server_list = await self._resolve_server()

0 commit comments

Comments
 (0)