Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Add type hints for most HomeServer parameters (#11095)
Browse files Browse the repository at this point in the history
  • Loading branch information
squahtx authored Oct 22, 2021
1 parent b9ce53e commit 2b82ec4
Show file tree
Hide file tree
Showing 58 changed files with 342 additions and 143 deletions.
1 change: 1 addition & 0 deletions changelog.d/11095.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add type hints to most `HomeServer` parameters.
8 changes: 4 additions & 4 deletions synapse/app/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ def listen_ssl(
return r


def refresh_certificate(hs):
def refresh_certificate(hs: "HomeServer"):
"""
Refresh the TLS certificates that Synapse is using by re-reading them from
disk and updating the TLS context factories to use them.
Expand Down Expand Up @@ -419,11 +419,11 @@ def run_sighup(*args, **kwargs):
atexit.register(gc.freeze)


def setup_sentry(hs):
def setup_sentry(hs: "HomeServer"):
"""Enable sentry integration, if enabled in configuration
Args:
hs (synapse.server.HomeServer)
hs
"""

if not hs.config.metrics.sentry_enabled:
Expand All @@ -449,7 +449,7 @@ def setup_sentry(hs):
scope.set_tag("worker_name", name)


def setup_sdnotify(hs):
def setup_sdnotify(hs: "HomeServer"):
"""Adds process state hooks to tell systemd what we are up to."""

# Tell systemd our state, if we're using it. This will silently fail if
Expand Down
4 changes: 2 additions & 2 deletions synapse/app/admin_cmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,11 @@ class AdminCmdServer(HomeServer):
DATASTORE_CLASS = AdminCmdSlavedStore


async def export_data_command(hs, args):
async def export_data_command(hs: HomeServer, args):
"""Export data for a user.
Args:
hs (HomeServer)
hs
args (argparse.Namespace)
"""

Expand Down
4 changes: 2 additions & 2 deletions synapse/app/generic_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,10 @@ class KeyUploadServlet(RestServlet):

PATTERNS = client_patterns("/keys/upload(/(?P<device_id>[^/]+))?$")

def __init__(self, hs):
def __init__(self, hs: HomeServer):
"""
Args:
hs (synapse.server.HomeServer): server
hs: server
"""
super().__init__()
self.auth = hs.get_auth()
Expand Down
2 changes: 1 addition & 1 deletion synapse/app/homeserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ def format_config_error(e: ConfigError) -> Iterator[str]:
e = e.__cause__


def run(hs):
def run(hs: HomeServer):
PROFILE_SYNAPSE = False
if PROFILE_SYNAPSE:

Expand Down
8 changes: 6 additions & 2 deletions synapse/app/phone_stats_home.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,15 @@
import math
import resource
import sys
from typing import TYPE_CHECKING

from prometheus_client import Gauge

from synapse.metrics.background_process_metrics import wrap_as_background_process

if TYPE_CHECKING:
from synapse.server import HomeServer

logger = logging.getLogger("synapse.app.homeserver")

# Contains the list of processes we will be monitoring
Expand All @@ -41,7 +45,7 @@


@wrap_as_background_process("phone_stats_home")
async def phone_stats_home(hs, stats, stats_process=_stats_process):
async def phone_stats_home(hs: "HomeServer", stats, stats_process=_stats_process):
logger.info("Gathering stats for reporting")
now = int(hs.get_clock().time())
uptime = int(now - hs.start_time)
Expand Down Expand Up @@ -142,7 +146,7 @@ async def phone_stats_home(hs, stats, stats_process=_stats_process):
logger.warning("Error reporting stats: %s", e)


def start_phone_stats_home(hs):
def start_phone_stats_home(hs: "HomeServer"):
"""
Start the background tasks which report phone home stats.
"""
Expand Down
3 changes: 2 additions & 1 deletion synapse/appservice/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

if TYPE_CHECKING:
from synapse.appservice import ApplicationService
from synapse.server import HomeServer

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -84,7 +85,7 @@ class ApplicationServiceApi(SimpleHttpClient):
pushing.
"""

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.clock = hs.get_clock()

Expand Down
9 changes: 8 additions & 1 deletion synapse/config/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import sys
import threading
from string import Template
from typing import TYPE_CHECKING

import yaml
from zope.interface import implementer
Expand All @@ -38,6 +39,9 @@

from ._base import Config, ConfigError

if TYPE_CHECKING:
from synapse.server import HomeServer

DEFAULT_LOG_CONFIG = Template(
"""\
# Log configuration for Synapse.
Expand Down Expand Up @@ -306,7 +310,10 @@ def _reload_logging_config(log_config_path):


def setup_logging(
hs, config, use_worker_options=False, logBeginner: LogBeginner = globalLogBeginner
hs: "HomeServer",
config,
use_worker_options=False,
logBeginner: LogBeginner = globalLogBeginner,
) -> None:
"""
Set up the logging subsystem.
Expand Down
7 changes: 6 additions & 1 deletion synapse/federation/federation_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.
import logging
from collections import namedtuple
from typing import TYPE_CHECKING

from synapse.api.constants import MAX_DEPTH, EventContentFields, EventTypes, Membership
from synapse.api.errors import Codes, SynapseError
Expand All @@ -25,11 +26,15 @@
from synapse.http.servlet import assert_params_in_dict
from synapse.types import JsonDict, get_domain_from_id

if TYPE_CHECKING:
from synapse.server import HomeServer


logger = logging.getLogger(__name__)


class FederationBase:
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.hs = hs

self.server_name = hs.hostname
Expand Down
9 changes: 5 additions & 4 deletions synapse/federation/federation_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,7 @@ async def _process_edu(edu_dict: JsonDict) -> None:

async def on_room_state_request(
self, origin: str, room_id: str, event_id: Optional[str]
) -> Tuple[int, Dict[str, Any]]:
) -> Tuple[int, JsonDict]:
origin_host, _ = parse_server_name(origin)
await self.check_server_matches_acl(origin_host, room_id)

Expand All @@ -481,7 +481,7 @@ async def on_room_state_request(
# - but that's non-trivial to get right, and anyway somewhat defeats
# the point of the linearizer.
with (await self._server_linearizer.queue((origin, room_id))):
resp = dict(
resp: JsonDict = dict(
await self._state_resp_cache.wrap(
(room_id, event_id),
self._on_context_state_request_compute,
Expand Down Expand Up @@ -1061,11 +1061,12 @@ async def _process_incoming_pdus_in_room_inner(

origin, event = next

lock = await self.store.try_acquire_lock(
new_lock = await self.store.try_acquire_lock(
_INBOUND_EVENT_HANDLING_LOCK_NAME, room_id
)
if not lock:
if not new_lock:
return
lock = new_lock

def __str__(self) -> str:
return "<ReplicationLayer(%s)>" % self.server_name
Expand Down
8 changes: 6 additions & 2 deletions synapse/http/matrixfederationclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import urllib.parse
from io import BytesIO, StringIO
from typing import (
TYPE_CHECKING,
Callable,
Dict,
Generic,
Expand Down Expand Up @@ -73,6 +74,9 @@
from synapse.util.async_helpers import timeout_deferred
from synapse.util.metrics import Measure

if TYPE_CHECKING:
from synapse.server import HomeServer

logger = logging.getLogger(__name__)

outgoing_requests_counter = Counter(
Expand Down Expand Up @@ -319,7 +323,7 @@ class MatrixFederationHttpClient:
requests.
"""

def __init__(self, hs, tls_client_options_factory):
def __init__(self, hs: "HomeServer", tls_client_options_factory):
self.hs = hs
self.signing_key = hs.signing_key
self.server_name = hs.hostname
Expand Down Expand Up @@ -711,7 +715,7 @@ def build_auth_headers(
Returns:
A list of headers to be added as "Authorization:" headers
"""
request = {
request: JsonDict = {
"method": method.decode("ascii"),
"uri": url_bytes.decode("ascii"),
"origin": self.server_name,
Expand Down
19 changes: 12 additions & 7 deletions synapse/http/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from http import HTTPStatus
from inspect import isawaitable
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Expand Down Expand Up @@ -61,6 +62,9 @@
from synapse.util.caches import intern_dict
from synapse.util.iterutils import chunk_seq

if TYPE_CHECKING:
from synapse.server import HomeServer

logger = logging.getLogger(__name__)

HTML_ERROR_TEMPLATE = """<!DOCTYPE html>
Expand Down Expand Up @@ -343,6 +347,11 @@ def _send_error_response(
return_json_error(f, request)


_PathEntry = collections.namedtuple(
"_PathEntry", ["pattern", "callback", "servlet_classname"]
)


class JsonResource(DirectServeJsonResource):
"""This implements the HttpServer interface and provides JSON support for
Resources.
Expand All @@ -359,14 +368,10 @@ class JsonResource(DirectServeJsonResource):

isLeaf = True

_PathEntry = collections.namedtuple(
"_PathEntry", ["pattern", "callback", "servlet_classname"]
)

def __init__(self, hs, canonical_json=True, extract_context=False):
def __init__(self, hs: "HomeServer", canonical_json=True, extract_context=False):
super().__init__(canonical_json, extract_context)
self.clock = hs.get_clock()
self.path_regexs = {}
self.path_regexs: Dict[bytes, List[_PathEntry]] = {}
self.hs = hs

def register_paths(self, method, path_patterns, callback, servlet_classname):
Expand All @@ -391,7 +396,7 @@ def register_paths(self, method, path_patterns, callback, servlet_classname):
for path_pattern in path_patterns:
logger.debug("Registering for %s %s", method, path_pattern.pattern)
self.path_regexs.setdefault(method, []).append(
self._PathEntry(path_pattern, callback, servlet_classname)
_PathEntry(path_pattern, callback, servlet_classname)
)

def _get_handler_for_request(
Expand Down
9 changes: 7 additions & 2 deletions synapse/replication/http/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from synapse.http.server import JsonResource
from synapse.replication.http import (
account_data,
Expand All @@ -26,16 +28,19 @@
streams,
)

if TYPE_CHECKING:
from synapse.server import HomeServer

REPLICATION_PREFIX = "/_synapse/replication"


class ReplicationRestResource(JsonResource):
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
# We enable extracting jaeger contexts here as these are internal APIs.
super().__init__(hs, canonical_json=False, extract_context=True)
self.register_servlets(hs)

def register_servlets(self, hs):
def register_servlets(self, hs: "HomeServer"):
send_event.register_servlets(hs, self)
federation.register_servlets(hs, self)
presence.register_servlets(hs, self)
Expand Down
8 changes: 5 additions & 3 deletions synapse/replication/http/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import re
import urllib
from inspect import signature
from typing import TYPE_CHECKING, Dict, List, Tuple
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Tuple

from prometheus_client import Counter, Gauge

Expand Down Expand Up @@ -156,7 +156,7 @@ async def _handle_request(self, request, **kwargs):
pass

@classmethod
def make_client(cls, hs):
def make_client(cls, hs: "HomeServer"):
"""Create a client that makes requests.
Returns a callable that accepts the same parameters as
Expand Down Expand Up @@ -208,7 +208,9 @@ async def send_request(*, instance_name="master", **kwargs):
url_args.append(txn_id)

if cls.METHOD == "POST":
request_func = client.post_json_get_json
request_func: Callable[
..., Awaitable[Any]
] = client.post_json_get_json
elif cls.METHOD == "PUT":
request_func = client.put_json
elif cls.METHOD == "GET":
Expand Down
Loading

0 comments on commit 2b82ec4

Please sign in to comment.