Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add typing to synapse.federation.sender #6871

Merged
merged 10 commits into from
Feb 7, 2020
1 change: 1 addition & 0 deletions changelog.d/6871.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add typing to `synapse.federation.sender` and port to async/await.
7 changes: 6 additions & 1 deletion synapse/federation/federation_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,12 @@ async def _handle_edus_in_txn(self, origin: str, transaction: Transaction):
async def _process_edu(edu_dict):
received_edus_counter.inc()

edu = Edu(**edu_dict)
edu = Edu(
origin=origin,
destination=self.server_name,
edu_type=edu_dict["edu_type"],
content=edu_dict["content"],
)
await self.registry.on_edu(edu.edu_type, origin, edu.content)

await concurrently_execute(
Expand Down
99 changes: 48 additions & 51 deletions synapse/federation/sender/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.

import logging
from typing import Dict, Hashable, Iterable, List, Optional, Set

from six import itervalues

Expand All @@ -23,6 +24,7 @@

import synapse
import synapse.metrics
from synapse.events import EventBase
from synapse.federation.sender.per_destination_queue import PerDestinationQueue
from synapse.federation.sender.transaction_manager import TransactionManager
from synapse.federation.units import Edu
Expand All @@ -39,6 +41,8 @@
events_processed_counter,
)
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.presence import UserPresenceState
from synapse.types import ReadReceipt
from synapse.util.metrics import Measure, measure_func

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -68,7 +72,7 @@ def __init__(self, hs: "synapse.server.HomeServer"):
self._transaction_manager = TransactionManager(hs)

# map from destination to PerDestinationQueue
self._per_destination_queues = {} # type: dict[str, PerDestinationQueue]
self._per_destination_queues = {} # type: Dict[str, PerDestinationQueue]

LaterGauge(
"synapse_federation_transaction_queue_pending_destinations",
Expand All @@ -84,7 +88,7 @@ def __init__(self, hs: "synapse.server.HomeServer"):
# Map of user_id -> UserPresenceState for all the pending presence
# to be sent out by user_id. Entries here get processed and put in
# pending_presence_by_dest
self.pending_presence = {}
self.pending_presence = {} # type: Dict[str, UserPresenceState]

LaterGauge(
"synapse_federation_transaction_queue_pending_pdus",
Expand Down Expand Up @@ -116,28 +120,25 @@ def __init__(self, hs: "synapse.server.HomeServer"):
# and that there is a pending call to _flush_rrs_for_room in the system.
self._queues_awaiting_rr_flush_by_room = (
{}
) # type: dict[str, set[PerDestinationQueue]]
) # type: Dict[str, Set[PerDestinationQueue]]

self._rr_txn_interval_per_room_ms = (
1000.0 / hs.get_config().federation_rr_transactions_per_room_per_second
1000.0 / hs.config.federation_rr_transactions_per_room_per_second
)

def _get_per_destination_queue(self, destination):
def _get_per_destination_queue(self, destination: str) -> PerDestinationQueue:
"""Get or create a PerDestinationQueue for the given destination

Args:
destination (str): server_name of remote server

Returns:
PerDestinationQueue
destination: server_name of remote server
"""
queue = self._per_destination_queues.get(destination)
if not queue:
queue = PerDestinationQueue(self.hs, self._transaction_manager, destination)
self._per_destination_queues[destination] = queue
return queue

def notify_new_events(self, current_id):
def notify_new_events(self, current_id: int) -> None:
"""This gets called when we have some new events we might want to
send out to other servers.
"""
Expand All @@ -151,13 +152,12 @@ def notify_new_events(self, current_id):
"process_event_queue_for_federation", self._process_event_queue_loop
)

@defer.inlineCallbacks
def _process_event_queue_loop(self):
async def _process_event_queue_loop(self) -> None:
try:
self._is_processing = True
while True:
last_token = yield self.store.get_federation_out_pos("events")
next_token, events = yield self.store.get_all_new_events_stream(
last_token = await self.store.get_federation_out_pos("events")
next_token, events = await self.store.get_all_new_events_stream(
last_token, self._last_poked_id, limit=100
)

Expand All @@ -166,8 +166,7 @@ def _process_event_queue_loop(self):
if not events and next_token >= self._last_poked_id:
break

@defer.inlineCallbacks
def handle_event(event):
async def handle_event(event: EventBase) -> None:
# Only send events for this server.
send_on_behalf_of = event.internal_metadata.get_send_on_behalf_of()
is_mine = self.is_mine_id(event.sender)
Expand All @@ -184,7 +183,7 @@ def handle_event(event):
# Otherwise if the last member on a server in a room is
# banned then it won't receive the event because it won't
# be in the room after the ban.
destinations = yield self.state.get_hosts_in_room_at_events(
destinations = await self.state.get_hosts_in_room_at_events(
event.room_id, event_ids=event.prev_event_ids()
)
except Exception:
Expand All @@ -206,17 +205,16 @@ def handle_event(event):

self._send_pdu(event, destinations)

@defer.inlineCallbacks
def handle_room_events(events):
async def handle_room_events(events: Iterable[EventBase]) -> None:
with Measure(self.clock, "handle_room_events"):
for event in events:
yield handle_event(event)
await handle_event(event)

events_by_room = {}
events_by_room = {} # type: Dict[str, List[EventBase]]
for event in events:
events_by_room.setdefault(event.room_id, []).append(event)

yield make_deferred_yieldable(
await make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(handle_room_events, evs)
Expand All @@ -226,11 +224,11 @@ def handle_room_events(events):
)
)

yield self.store.update_federation_out_pos("events", next_token)
await self.store.update_federation_out_pos("events", next_token)

if events:
now = self.clock.time_msec()
ts = yield self.store.get_received_ts(events[-1].event_id)
ts = await self.store.get_received_ts(events[-1].event_id)

synapse.metrics.event_processing_lag.labels(
"federation_sender"
Expand All @@ -254,7 +252,7 @@ def handle_room_events(events):
finally:
self._is_processing = False

def _send_pdu(self, pdu, destinations):
def _send_pdu(self, pdu: EventBase, destinations: Iterable[str]) -> None:
# We loop through all destinations to see whether we already have
# a transaction in progress. If we do, stick it in the pending_pdus
# table and we'll get back to it later.
Expand All @@ -276,11 +274,11 @@ def _send_pdu(self, pdu, destinations):
self._get_per_destination_queue(destination).send_pdu(pdu, order)

@defer.inlineCallbacks
def send_read_receipt(self, receipt):
def send_read_receipt(self, receipt: ReadReceipt):
"""Send a RR to any other servers in the room

Args:
receipt (synapse.types.ReadReceipt): receipt to be sent
receipt: receipt to be sent
"""

# Some background on the rate-limiting going on here.
Expand Down Expand Up @@ -343,7 +341,7 @@ def send_read_receipt(self, receipt):
else:
queue.flush_read_receipts_for_room(room_id)

def _schedule_rr_flush_for_room(self, room_id, n_domains):
def _schedule_rr_flush_for_room(self, room_id: str, n_domains: int) -> None:
# that is going to cause approximately len(domains) transactions, so now back
# off for that multiplied by RR_TXN_INTERVAL_PER_ROOM
backoff_ms = self._rr_txn_interval_per_room_ms * n_domains
Expand All @@ -352,7 +350,7 @@ def _schedule_rr_flush_for_room(self, room_id, n_domains):
self.clock.call_later(backoff_ms, self._flush_rrs_for_room, room_id)
self._queues_awaiting_rr_flush_by_room[room_id] = set()

def _flush_rrs_for_room(self, room_id):
def _flush_rrs_for_room(self, room_id: str) -> None:
queues = self._queues_awaiting_rr_flush_by_room.pop(room_id)
logger.debug("Flushing RRs in %s to %s", room_id, queues)

Expand All @@ -368,14 +366,11 @@ def _flush_rrs_for_room(self, room_id):

@preserve_fn # the caller should not yield on this
@defer.inlineCallbacks
def send_presence(self, states):
def send_presence(self, states: List[UserPresenceState]):
"""Send the new presence states to the appropriate destinations.

This actually queues up the presence states ready for sending and
triggers a background task to process them and send out the transactions.

Args:
states (list(UserPresenceState))
"""
if not self.hs.config.use_presence:
# No-op if presence is disabled.
Expand Down Expand Up @@ -412,11 +407,10 @@ def send_presence(self, states):
finally:
self._processing_pending_presence = False

def send_presence_to_destinations(self, states, destinations):
def send_presence_to_destinations(
self, states: List[UserPresenceState], destinations: List[str]
) -> None:
"""Send the given presence states to the given destinations.

Args:
states (list[UserPresenceState])
destinations (list[str])
"""

Expand All @@ -431,12 +425,9 @@ def send_presence_to_destinations(self, states, destinations):

@measure_func("txnqueue._process_presence")
@defer.inlineCallbacks
def _process_presence_inner(self, states):
def _process_presence_inner(self, states: List[UserPresenceState]):
"""Given a list of states populate self.pending_presence_by_dest and
poke to send a new transaction to each destination

Args:
states (list(UserPresenceState))
"""
hosts_and_states = yield get_interested_remotes(self.store, states, self.state)

Expand All @@ -446,14 +437,20 @@ def _process_presence_inner(self, states):
continue
self._get_per_destination_queue(destination).send_presence(states)

def build_and_send_edu(self, destination, edu_type, content, key=None):
def build_and_send_edu(
self,
destination: str,
edu_type: str,
content: dict,
key: Optional[Hashable] = None,
):
"""Construct an Edu object, and queue it for sending

Args:
destination (str): name of server to send to
edu_type (str): type of EDU to send
content (dict): content of EDU
key (Any|None): clobbering key for this edu
destination: name of server to send to
edu_type: type of EDU to send
content: content of EDU
key: clobbering key for this edu
"""
if destination == self.server_name:
logger.info("Not sending EDU to ourselves")
Expand All @@ -468,20 +465,20 @@ def build_and_send_edu(self, destination, edu_type, content, key=None):

self.send_edu(edu, key)

def send_edu(self, edu, key):
def send_edu(self, edu: Edu, key: Optional[Hashable]):
"""Queue an EDU for sending

Args:
edu (Edu): edu to send
key (Any|None): clobbering key for this edu
edu: edu to send
key: clobbering key for this edu
"""
queue = self._get_per_destination_queue(edu.destination)
if key:
queue.send_keyed_edu(edu, key)
else:
queue.send_edu(edu)

def send_device_messages(self, destination):
def send_device_messages(self, destination: str):
if destination == self.server_name:
logger.warning("Not sending device update to ourselves")
return
Expand All @@ -501,5 +498,5 @@ def wake_destination(self, destination: str):

self._get_per_destination_queue(destination).attempt_new_transaction()

def get_current_token(self):
def get_current_token(self) -> int:
return 0
Loading