Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add type hints for most HomeServer parameters #11095

Merged
merged 19 commits into from
Oct 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/11095.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add type hints to most `HomeServer` parameters.
8 changes: 4 additions & 4 deletions synapse/app/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ def listen_ssl(
return r


def refresh_certificate(hs):
def refresh_certificate(hs: "HomeServer"):
"""
Refresh the TLS certificates that Synapse is using by re-reading them from
disk and updating the TLS context factories to use them.
Expand Down Expand Up @@ -419,11 +419,11 @@ def run_sighup(*args, **kwargs):
atexit.register(gc.freeze)


def setup_sentry(hs):
def setup_sentry(hs: "HomeServer"):
"""Enable sentry integration, if enabled in configuration

Args:
hs (synapse.server.HomeServer)
hs
"""

if not hs.config.metrics.sentry_enabled:
Expand All @@ -449,7 +449,7 @@ def setup_sentry(hs):
scope.set_tag("worker_name", name)


def setup_sdnotify(hs):
def setup_sdnotify(hs: "HomeServer"):
"""Adds process state hooks to tell systemd what we are up to."""

# Tell systemd our state, if we're using it. This will silently fail if
Expand Down
4 changes: 2 additions & 2 deletions synapse/app/admin_cmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,11 @@ class AdminCmdServer(HomeServer):
DATASTORE_CLASS = AdminCmdSlavedStore


async def export_data_command(hs, args):
async def export_data_command(hs: HomeServer, args):
"""Export data for a user.

Args:
hs (HomeServer)
hs
args (argparse.Namespace)
"""

Expand Down
4 changes: 2 additions & 2 deletions synapse/app/generic_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,10 @@ class KeyUploadServlet(RestServlet):

PATTERNS = client_patterns("/keys/upload(/(?P<device_id>[^/]+))?$")

def __init__(self, hs):
def __init__(self, hs: HomeServer):
"""
Args:
hs (synapse.server.HomeServer): server
hs: server
"""
super().__init__()
self.auth = hs.get_auth()
Expand Down
2 changes: 1 addition & 1 deletion synapse/app/homeserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ def format_config_error(e: ConfigError) -> Iterator[str]:
e = e.__cause__


def run(hs):
def run(hs: HomeServer):
PROFILE_SYNAPSE = False
if PROFILE_SYNAPSE:

Expand Down
8 changes: 6 additions & 2 deletions synapse/app/phone_stats_home.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,15 @@
import math
import resource
import sys
from typing import TYPE_CHECKING

from prometheus_client import Gauge

from synapse.metrics.background_process_metrics import wrap_as_background_process

if TYPE_CHECKING:
from synapse.server import HomeServer

logger = logging.getLogger("synapse.app.homeserver")

# Contains the list of processes we will be monitoring
Expand All @@ -41,7 +45,7 @@


@wrap_as_background_process("phone_stats_home")
async def phone_stats_home(hs, stats, stats_process=_stats_process):
async def phone_stats_home(hs: "HomeServer", stats, stats_process=_stats_process):
logger.info("Gathering stats for reporting")
now = int(hs.get_clock().time())
uptime = int(now - hs.start_time)
Expand Down Expand Up @@ -142,7 +146,7 @@ async def phone_stats_home(hs, stats, stats_process=_stats_process):
logger.warning("Error reporting stats: %s", e)


def start_phone_stats_home(hs):
def start_phone_stats_home(hs: "HomeServer"):
"""
Start the background tasks which report phone home stats.
"""
Expand Down
3 changes: 2 additions & 1 deletion synapse/appservice/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

if TYPE_CHECKING:
from synapse.appservice import ApplicationService
from synapse.server import HomeServer

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -84,7 +85,7 @@ class ApplicationServiceApi(SimpleHttpClient):
pushing.
"""

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.clock = hs.get_clock()

Expand Down
9 changes: 8 additions & 1 deletion synapse/config/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import sys
import threading
from string import Template
from typing import TYPE_CHECKING

import yaml
from zope.interface import implementer
Expand All @@ -38,6 +39,9 @@

from ._base import Config, ConfigError

if TYPE_CHECKING:
from synapse.server import HomeServer

DEFAULT_LOG_CONFIG = Template(
"""\
# Log configuration for Synapse.
Expand Down Expand Up @@ -306,7 +310,10 @@ def _reload_logging_config(log_config_path):


def setup_logging(
hs, config, use_worker_options=False, logBeginner: LogBeginner = globalLogBeginner
hs: "HomeServer",
config,
use_worker_options=False,
logBeginner: LogBeginner = globalLogBeginner,
) -> None:
"""
Set up the logging subsystem.
Expand Down
7 changes: 6 additions & 1 deletion synapse/federation/federation_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.
import logging
from collections import namedtuple
from typing import TYPE_CHECKING

from synapse.api.constants import MAX_DEPTH, EventContentFields, EventTypes, Membership
from synapse.api.errors import Codes, SynapseError
Expand All @@ -25,11 +26,15 @@
from synapse.http.servlet import assert_params_in_dict
from synapse.types import JsonDict, get_domain_from_id

if TYPE_CHECKING:
from synapse.server import HomeServer


logger = logging.getLogger(__name__)


class FederationBase:
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.hs = hs

self.server_name = hs.hostname
Expand Down
9 changes: 5 additions & 4 deletions synapse/federation/federation_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,7 @@ async def _process_edu(edu_dict: JsonDict) -> None:

async def on_room_state_request(
self, origin: str, room_id: str, event_id: Optional[str]
) -> Tuple[int, Dict[str, Any]]:
) -> Tuple[int, JsonDict]:
origin_host, _ = parse_server_name(origin)
await self.check_server_matches_acl(origin_host, room_id)

Expand All @@ -481,7 +481,7 @@ async def on_room_state_request(
# - but that's non-trivial to get right, and anyway somewhat defeats
# the point of the linearizer.
with (await self._server_linearizer.queue((origin, room_id))):
resp = dict(
resp: JsonDict = dict(
await self._state_resp_cache.wrap(
(room_id, event_id),
self._on_context_state_request_compute,
Expand Down Expand Up @@ -1061,11 +1061,12 @@ async def _process_incoming_pdus_in_room_inner(

origin, event = next

lock = await self.store.try_acquire_lock(
new_lock = await self.store.try_acquire_lock(
_INBOUND_EVENT_HANDLING_LOCK_NAME, room_id
)
if not lock:
if not new_lock:
return
lock = new_lock

def __str__(self) -> str:
return "<ReplicationLayer(%s)>" % self.server_name
Expand Down
8 changes: 6 additions & 2 deletions synapse/http/matrixfederationclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import urllib.parse
from io import BytesIO, StringIO
from typing import (
TYPE_CHECKING,
Callable,
Dict,
Generic,
Expand Down Expand Up @@ -73,6 +74,9 @@
from synapse.util.async_helpers import timeout_deferred
from synapse.util.metrics import Measure

if TYPE_CHECKING:
from synapse.server import HomeServer

logger = logging.getLogger(__name__)

outgoing_requests_counter = Counter(
Expand Down Expand Up @@ -319,7 +323,7 @@ class MatrixFederationHttpClient:
requests.
"""

def __init__(self, hs, tls_client_options_factory):
def __init__(self, hs: "HomeServer", tls_client_options_factory):
self.hs = hs
self.signing_key = hs.signing_key
self.server_name = hs.hostname
Expand Down Expand Up @@ -711,7 +715,7 @@ def build_auth_headers(
Returns:
A list of headers to be added as "Authorization:" headers
"""
request = {
request: JsonDict = {
"method": method.decode("ascii"),
"uri": url_bytes.decode("ascii"),
"origin": self.server_name,
Expand Down
19 changes: 12 additions & 7 deletions synapse/http/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from http import HTTPStatus
from inspect import isawaitable
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Expand Down Expand Up @@ -61,6 +62,9 @@
from synapse.util.caches import intern_dict
from synapse.util.iterutils import chunk_seq

if TYPE_CHECKING:
from synapse.server import HomeServer

logger = logging.getLogger(__name__)

HTML_ERROR_TEMPLATE = """<!DOCTYPE html>
Expand Down Expand Up @@ -343,6 +347,11 @@ def _send_error_response(
return_json_error(f, request)


_PathEntry = collections.namedtuple(
"_PathEntry", ["pattern", "callback", "servlet_classname"]
)


class JsonResource(DirectServeJsonResource):
"""This implements the HttpServer interface and provides JSON support for
Resources.
Expand All @@ -359,14 +368,10 @@ class JsonResource(DirectServeJsonResource):

isLeaf = True

_PathEntry = collections.namedtuple(
"_PathEntry", ["pattern", "callback", "servlet_classname"]
)

def __init__(self, hs, canonical_json=True, extract_context=False):
def __init__(self, hs: "HomeServer", canonical_json=True, extract_context=False):
super().__init__(canonical_json, extract_context)
self.clock = hs.get_clock()
self.path_regexs = {}
self.path_regexs: Dict[bytes, List[_PathEntry]] = {}
self.hs = hs

def register_paths(self, method, path_patterns, callback, servlet_classname):
Expand All @@ -391,7 +396,7 @@ def register_paths(self, method, path_patterns, callback, servlet_classname):
for path_pattern in path_patterns:
logger.debug("Registering for %s %s", method, path_pattern.pattern)
self.path_regexs.setdefault(method, []).append(
self._PathEntry(path_pattern, callback, servlet_classname)
_PathEntry(path_pattern, callback, servlet_classname)
)

def _get_handler_for_request(
Expand Down
9 changes: 7 additions & 2 deletions synapse/replication/http/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from synapse.http.server import JsonResource
from synapse.replication.http import (
account_data,
Expand All @@ -26,16 +28,19 @@
streams,
)

if TYPE_CHECKING:
from synapse.server import HomeServer

REPLICATION_PREFIX = "/_synapse/replication"


class ReplicationRestResource(JsonResource):
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
# We enable extracting jaeger contexts here as these are internal APIs.
super().__init__(hs, canonical_json=False, extract_context=True)
self.register_servlets(hs)

def register_servlets(self, hs):
def register_servlets(self, hs: "HomeServer"):
send_event.register_servlets(hs, self)
federation.register_servlets(hs, self)
presence.register_servlets(hs, self)
Expand Down
8 changes: 5 additions & 3 deletions synapse/replication/http/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import re
import urllib
from inspect import signature
from typing import TYPE_CHECKING, Dict, List, Tuple
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Tuple

from prometheus_client import Counter, Gauge

Expand Down Expand Up @@ -156,7 +156,7 @@ async def _handle_request(self, request, **kwargs):
pass

@classmethod
def make_client(cls, hs):
def make_client(cls, hs: "HomeServer"):
"""Create a client that makes requests.

Returns a callable that accepts the same parameters as
Expand Down Expand Up @@ -208,7 +208,9 @@ async def send_request(*, instance_name="master", **kwargs):
url_args.append(txn_id)

if cls.METHOD == "POST":
request_func = client.post_json_get_json
request_func: Callable[
..., Awaitable[Any]
] = client.post_json_get_json
elif cls.METHOD == "PUT":
request_func = client.put_json
elif cls.METHOD == "GET":
Expand Down
Loading