Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add missing type hints to synapse.http. #11571

Merged
merged 5 commits into from
Dec 14, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/11571.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add missing type hints to `synapse.http`.
3 changes: 3 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,9 @@ disallow_untyped_defs = False
[mypy-synapse.handlers.*]
disallow_untyped_defs = True

[mypy-synapse.http.server]
disallow_untyped_defs = True

[mypy-synapse.metrics.*]
disallow_untyped_defs = True

Expand Down
6 changes: 3 additions & 3 deletions synapse/http/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,15 @@
class RequestTimedOutError(SynapseError):
"""Exception representing timeout of an outbound request"""

def __init__(self, msg):
def __init__(self, msg: str):
super().__init__(504, msg)


ACCESS_TOKEN_RE = re.compile(r"(\?.*access(_|%5[Ff])token=)[^&]*(.*)$")
CLIENT_SECRET_RE = re.compile(r"(\?.*client(_|%5[Ff])secret=)[^&]*(.*)$")


def redact_uri(uri):
def redact_uri(uri: str) -> str:
"""Strips sensitive information from the uri replaces with <redacted>"""
uri = ACCESS_TOKEN_RE.sub(r"\1<redacted>\3", uri)
return CLIENT_SECRET_RE.sub(r"\1<redacted>\3", uri)
Expand All @@ -46,7 +46,7 @@ class QuieterFileBodyProducer(FileBodyProducer):
https://twistedmatrix.com/trac/ticket/6528
"""

def stopProducing(self):
def stopProducing(self) -> None:
try:
FileBodyProducer.stopProducing(self)
except task.TaskStopped:
Expand Down
12 changes: 8 additions & 4 deletions synapse/http/additional_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple

from twisted.web.server import Request

Expand All @@ -32,7 +32,11 @@ class AdditionalResource(DirectServeJsonResource):
and exception handling.
"""

def __init__(self, hs: "HomeServer", handler):
def __init__(
self,
hs: "HomeServer",
handler: Callable[[Request], Awaitable[Optional[Tuple[int, Any]]]],
):
"""Initialise AdditionalResource

The ``handler`` should return a deferred which completes when it has
Expand All @@ -47,7 +51,7 @@ def __init__(self, hs: "HomeServer", handler):
super().__init__()
self._handler = handler

def _async_render(self, request: Request):
async def _async_render(self, request: Request) -> Optional[Tuple[int, Any]]:
# Cheekily pass the result straight through, so we don't need to worry
# if its an awaitable or not.
return self._handler(request)
return await self._handler(request)
90 changes: 53 additions & 37 deletions synapse/http/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
Iterable,
Iterator,
List,
NoReturn,
Optional,
Pattern,
Tuple,
Expand Down Expand Up @@ -170,7 +171,9 @@ def return_html_error(
respond_with_html(request, code, body)


def wrap_async_request_handler(h):
def wrap_async_request_handler(
h: Callable[["_AsyncResource", SynapseRequest], Awaitable[None]]
) -> Callable[["_AsyncResource", SynapseRequest], "defer.Deferred[None]"]:
"""Wraps an async request handler so that it calls request.processing.

This helps ensure that work done by the request handler after the request is completed
Expand All @@ -183,7 +186,9 @@ def wrap_async_request_handler(h):
logged until the deferred completes.
"""

async def wrapped_async_request_handler(self, request):
async def wrapped_async_request_handler(
self: "_AsyncResource", request: SynapseRequest
) -> None:
with request.processing():
await h(self, request)

Expand Down Expand Up @@ -240,18 +245,18 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
context from the request the servlet is handling.
"""

def __init__(self, extract_context=False):
def __init__(self, extract_context: bool = False):
super().__init__()

self._extract_context = extract_context

def render(self, request):
def render(self, request: SynapseRequest) -> int:
"""This gets called by twisted every time someone sends us a request."""
defer.ensureDeferred(self._async_render_wrapper(request))
return NOT_DONE_YET

@wrap_async_request_handler
async def _async_render_wrapper(self, request: SynapseRequest):
async def _async_render_wrapper(self, request: SynapseRequest) -> None:
"""This is a wrapper that delegates to `_async_render` and handles
exceptions, return values, metrics, etc.
"""
Expand All @@ -271,7 +276,7 @@ async def _async_render_wrapper(self, request: SynapseRequest):
f = failure.Failure()
self._send_error_response(f, request)

async def _async_render(self, request: Request):
async def _async_render(self, request: SynapseRequest) -> Optional[Tuple[int, Any]]:
"""Delegates to `_async_render_<METHOD>` methods, or returns a 400 if
no appropriate method exists. Can be overridden in sub classes for
different routing.
Expand Down Expand Up @@ -318,7 +323,7 @@ class DirectServeJsonResource(_AsyncResource):
formatting responses and errors as JSON.
"""

def __init__(self, canonical_json=False, extract_context=False):
def __init__(self, canonical_json: bool = False, extract_context: bool = False):
super().__init__(extract_context)
self.canonical_json = canonical_json

Expand All @@ -327,7 +332,7 @@ def _send_response(
request: SynapseRequest,
code: int,
response_object: Any,
):
) -> None:
"""Implements _AsyncResource._send_response"""
# TODO: Only enable CORS for the requests that need it.
respond_with_json(
Expand Down Expand Up @@ -368,34 +373,45 @@ class JsonResource(DirectServeJsonResource):

isLeaf = True

def __init__(self, hs: "HomeServer", canonical_json=True, extract_context=False):
def __init__(
self,
hs: "HomeServer",
canonical_json: bool = True,
extract_context: bool = False,
):
super().__init__(canonical_json, extract_context)
self.clock = hs.get_clock()
self.path_regexs: Dict[bytes, List[_PathEntry]] = {}
self.hs = hs

def register_paths(self, method, path_patterns, callback, servlet_classname):
def register_paths(
self,
method: str,
path_patterns: Iterable[Pattern],
callback: ServletCallback,
servlet_classname: str,
) -> None:
"""
Registers a request handler against a regular expression. Later request URLs are
checked against these regular expressions in order to identify an appropriate
handler for that request.

Args:
method (str): GET, POST etc
method: GET, POST etc

path_patterns (Iterable[str]): A list of regular expressions to which
the request URLs are compared.
path_patterns: A list of regular expressions to which the request
URLs are compared.

callback (function): The handler for the request. Usually a Servlet
callback: The handler for the request. Usually a Servlet

servlet_classname (str): The name of the handler to be used in prometheus
servlet_classname: The name of the handler to be used in prometheus
and opentracing logs.
"""
method = method.encode("utf-8") # method is bytes on py3
method_bytes = method.encode("utf-8")

for path_pattern in path_patterns:
logger.debug("Registering for %s %s", method, path_pattern.pattern)
self.path_regexs.setdefault(method, []).append(
self.path_regexs.setdefault(method_bytes, []).append(
_PathEntry(path_pattern, callback, servlet_classname)
)

Expand Down Expand Up @@ -427,7 +443,7 @@ def _get_handler_for_request(
# Huh. No one wanted to handle that? Fiiiiiine. Send 400.
return _unrecognised_request_handler, "unrecognised_request_handler", {}

async def _async_render(self, request):
async def _async_render(self, request: SynapseRequest) -> Tuple[int, Any]:
callback, servlet_classname, group_dict = self._get_handler_for_request(request)

# Make sure we have an appropriate name for this handler in prometheus
Expand Down Expand Up @@ -468,7 +484,7 @@ def _send_response(
request: SynapseRequest,
code: int,
response_object: Any,
):
) -> None:
"""Implements _AsyncResource._send_response"""
# We expect to get bytes for us to write
assert isinstance(response_object, bytes)
Expand All @@ -492,35 +508,35 @@ class StaticResource(File):
Differs from the File resource by adding clickjacking protection.
"""

def render_GET(self, request: Request):
def render_GET(self, request: Request) -> bytes:
set_clickjacking_protection_headers(request)
return super().render_GET(request)


def _unrecognised_request_handler(request):
def _unrecognised_request_handler(request: Request) -> NoReturn:
"""Request handler for unrecognised requests

This is a request handler suitable for return from
_get_handler_for_request. It actually just raises an
UnrecognizedRequestError.

Args:
request (twisted.web.http.Request):
request: Unused, but passed in to match the signature of ServletCallback.
"""
raise UnrecognizedRequestError()


class RootRedirect(resource.Resource):
"""Redirects the root '/' path to another path."""

def __init__(self, path):
def __init__(self, path: str):
resource.Resource.__init__(self)
self.url = path

def render_GET(self, request):
def render_GET(self, request: Request) -> bytes:
return redirectTo(self.url.encode("ascii"), request)

def getChild(self, name, request):
def getChild(self, name: str, request: Request) -> resource.Resource:
if len(name) == 0:
return self # select ourselves as the child to render
return resource.Resource.getChild(self, name, request)
Expand All @@ -529,15 +545,15 @@ def getChild(self, name, request):
class OptionsResource(resource.Resource):
"""Responds to OPTION requests for itself and all children."""

def render_OPTIONS(self, request):
def render_OPTIONS(self, request: Request) -> bytes:
request.setResponseCode(204)
request.setHeader(b"Content-Length", b"0")

set_cors_headers(request)

return b""

def getChildWithDefault(self, path, request):
def getChildWithDefault(self, path: str, request: Request) -> resource.Resource:
if request.method == b"OPTIONS":
return self # select ourselves as the child to render
return resource.Resource.getChildWithDefault(self, path, request)
Expand Down Expand Up @@ -649,7 +665,7 @@ def respond_with_json(
json_object: Any,
send_cors: bool = False,
canonical_json: bool = True,
):
) -> Optional[int]:
"""Sends encoded JSON in response to the given request.

Args:
Expand Down Expand Up @@ -696,7 +712,7 @@ def respond_with_json_bytes(
code: int,
json_bytes: bytes,
send_cors: bool = False,
):
) -> Optional[int]:
"""Sends encoded JSON in response to the given request.

Args:
Expand All @@ -713,7 +729,7 @@ def respond_with_json_bytes(
logger.warning(
"Not sending response to request %s, already disconnected.", request
)
return
return None

request.setResponseCode(code)
request.setHeader(b"Content-Type", b"application/json")
Expand All @@ -731,7 +747,7 @@ async def _async_write_json_to_request_in_thread(
request: SynapseRequest,
json_encoder: Callable[[Any], bytes],
json_object: Any,
):
) -> None:
"""Encodes the given JSON object on a thread and then writes it to the
request.

Expand Down Expand Up @@ -773,7 +789,7 @@ def _write_bytes_to_request(request: Request, bytes_to_write: bytes) -> None:
_ByteProducer(request, bytes_generator)


def set_cors_headers(request: Request):
def set_cors_headers(request: Request) -> None:
"""Set the CORS headers so that javascript running in a web browsers can
use this API

Expand All @@ -790,14 +806,14 @@ def set_cors_headers(request: Request):
)


def respond_with_html(request: Request, code: int, html: str):
def respond_with_html(request: Request, code: int, html: str) -> None:
"""
Wraps `respond_with_html_bytes` by first encoding HTML from a str to UTF-8 bytes.
"""
respond_with_html_bytes(request, code, html.encode("utf-8"))


def respond_with_html_bytes(request: Request, code: int, html_bytes: bytes):
def respond_with_html_bytes(request: Request, code: int, html_bytes: bytes) -> None:
"""
Sends HTML (encoded as UTF-8 bytes) as the response to the given request.

Expand All @@ -815,7 +831,7 @@ def respond_with_html_bytes(request: Request, code: int, html_bytes: bytes):
logger.warning(
"Not sending response to request %s, already disconnected.", request
)
return
return None

request.setResponseCode(code)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
Expand All @@ -828,7 +844,7 @@ def respond_with_html_bytes(request: Request, code: int, html_bytes: bytes):
finish_request(request)


def set_clickjacking_protection_headers(request: Request):
def set_clickjacking_protection_headers(request: Request) -> None:
"""
Set headers to guard against clickjacking of embedded content.

Expand All @@ -850,7 +866,7 @@ def respond_with_redirect(request: Request, url: bytes) -> None:
finish_request(request)


def finish_request(request: Request):
def finish_request(request: Request) -> None:
"""Finish writing the response to the request.

Twisted throws a RuntimeException if the connection closed before the
Expand Down
3 changes: 2 additions & 1 deletion synapse/http/servlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from twisted.web.server import Request

from synapse.api.errors import Codes, SynapseError
from synapse.http.server import HttpServer
from synapse.types import JsonDict, RoomAlias, RoomID
from synapse.util import json_decoder

Expand Down Expand Up @@ -726,7 +727,7 @@ class attribute containing a pre-compiled regular expression. The automatic
into the appropriate HTTP response.
"""

def register(self, http_server):
def register(self, http_server: HttpServer) -> None:
"""Register this servlet with the given HTTP server."""
patterns = getattr(self, "PATTERNS", None)
if patterns:
Expand Down
Loading