Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add type hints to sydent/http/servlets #361

Merged
merged 15 commits into from
Jun 21, 2021
1 change: 1 addition & 0 deletions changelog.d/361.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add type hints to sydent/http/servlets/ to support mypy type checking.
23 changes: 15 additions & 8 deletions sydent/http/servlets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,14 @@
import functools
import json
import logging
from typing import Any, Dict, Iterable

from twisted.internet import defer
from twisted.python.failure import Failure
from twisted.web import server
from twisted.web.server import Request

from sydent.types import JsonDict
from sydent.util import json_decoder

logger = logging.getLogger(__name__)
Expand All @@ -38,7 +42,9 @@ def __init__(self, httpStatus, errcode, error):
self.error = error


def get_args(request, args, required=True):
def get_args(
request: Request, args: Iterable[str], required: bool = True
) -> Dict[str, Any]:
"""
Helper function to get arguments for an HTTP request.
Currently takes args from the top level keys of a json object or
Expand All @@ -50,7 +56,7 @@ def get_args(request, args, required=True):
:param request: The request received by the servlet.
:type request: twisted.web.server.Request
:param args: The args to look for in the request's parameters.
:type args: tuple[unicode]
:type args: Iterable[str]
:param required: Whether to raise a MatrixRestError with 400
M_MISSING_PARAMS if an argument is not found.
:type required: bool
Expand All @@ -62,6 +68,7 @@ def get_args(request, args, required=True):
are of type unicode.
:rtype: dict[unicode, any]
"""
assert request.path is not None
v1_path = request.path.startswith(b"/_matrix/identity/api/v1")

request_args = None
Expand Down Expand Up @@ -126,7 +133,7 @@ def get_args(request, args, required=True):

def jsonwrap(f):
@functools.wraps(f)
def inner(self, request, *args, **kwargs):
def inner(self, request: Request, *args, **kwargs) -> bytes:
"""
Runs a web handler function with the given request and parameters, then
converts its result into JSON and returns it. If an error happens, also sets
Expand Down Expand Up @@ -162,7 +169,7 @@ def inner(self, request, *args, **kwargs):


def deferjsonwrap(f):
def reqDone(resp, request):
def reqDone(resp: Dict[str, Any], request: Request) -> None:
"""
Converts the given response content into JSON and encodes it to bytes, then
writes it as the response to the given request with the right headers.
Expand All @@ -176,7 +183,7 @@ def reqDone(resp, request):
request.write(dict_to_json_bytes(resp))
request.finish()

def reqErr(failure, request):
def reqErr(failure: Failure, request: Request) -> None:
"""
Logs the given failure. If the failure is a MatrixRestError, writes a response
using the info it contains, otherwise responds with 500 Internal Server Error.
Expand Down Expand Up @@ -206,7 +213,7 @@ def reqErr(failure, request):
)
request.finish()

def inner(*args, **kwargs):
def inner(*args, **kwargs) -> int:
"""
Runs an asynchronous web handler function with the given arguments and add
reqDone and reqErr as the resulting Deferred's callbacks.
Expand All @@ -228,13 +235,13 @@ def inner(*args, **kwargs):
return inner


def send_cors(request):
def send_cors(request: Request) -> None:
request.setHeader("Access-Control-Allow-Origin", "*")
request.setHeader("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
request.setHeader("Access-Control-Allow-Headers", "*")


def dict_to_json_bytes(content):
def dict_to_json_bytes(content: JsonDict) -> bytes:
"""
Converts a dict into JSON and encodes it to bytes.

Expand Down
13 changes: 10 additions & 3 deletions sydent/http/servlets/accountservlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,28 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from twisted.web.resource import Resource
from twisted.web.server import Request

from sydent.http.auth import authV2
from sydent.http.servlets import jsonwrap, send_cors
from sydent.types import JsonDict

if TYPE_CHECKING:
from sydent.sydent import Sydent


class AccountServlet(Resource):
isLeaf = False

def __init__(self, syd):
def __init__(self, syd: "Sydent") -> None:
Resource.__init__(self)
self.sydent = syd

@jsonwrap
def render_GET(self, request):
def render_GET(self, request: Request) -> JsonDict:
"""
Return information about the user's account
(essentially just a 'who am i')
Expand All @@ -39,6 +46,6 @@ def render_GET(self, request):
"user_id": account.userId,
}

def render_OPTIONS(self, request):
def render_OPTIONS(self, request: Request) -> bytes:
send_cors(request)
return b""
13 changes: 10 additions & 3 deletions sydent/http/servlets/authenticated_bind_threepid_servlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from twisted.web.resource import Resource
from twisted.web.server import Request

from sydent.http.servlets import get_args, jsonwrap, send_cors
from sydent.types import JsonDict

if TYPE_CHECKING:
from sydent.sydent import Sydent


class AuthenticatedBindThreePidServlet(Resource):
Expand All @@ -23,12 +30,12 @@ class AuthenticatedBindThreePidServlet(Resource):
It is assumed that authentication happens out of band
"""

def __init__(self, sydent):
def __init__(self, sydent: "Sydent") -> None:
Resource.__init__(self)
self.sydent = sydent

@jsonwrap
def render_POST(self, request):
def render_POST(self, request: Request) -> JsonDict:
send_cors(request)
args = get_args(request, ("medium", "address", "mxid"))

Expand All @@ -38,6 +45,6 @@ def render_POST(self, request):
args["mxid"],
)

def render_OPTIONS(self, request):
def render_OPTIONS(self, request: Request) -> bytes:
send_cors(request)
return b""
12 changes: 9 additions & 3 deletions sydent/http/servlets/authenticated_unbind_threepid_servlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,29 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from twisted.web.resource import Resource
from twisted.web.server import Request

from sydent.http.servlets import get_args, jsonwrap, send_cors

if TYPE_CHECKING:
from sydent.sydent import Sydent


class AuthenticatedUnbindThreePidServlet(Resource):
"""A servlet which allows a caller to unbind any 3pid they want from an mxid

It is assumed that authentication happens out of band
"""

def __init__(self, sydent):
def __init__(self, sydent: "Sydent") -> None:
Resource.__init__(self)
self.sydent = sydent

@jsonwrap
def render_POST(self, request):
def render_POST(self, request: Request) -> None:
send_cors(request)
args = get_args(request, ("medium", "address", "mxid"))

Expand All @@ -39,6 +45,6 @@ def render_POST(self, request):
args["mxid"],
)

def render_OPTIONS(self, request):
def render_OPTIONS(self, request: Request) -> bytes:
send_cors(request)
return b""
12 changes: 9 additions & 3 deletions sydent/http/servlets/blindlysignstuffservlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,29 +13,35 @@
# limitations under the License.

import logging
from typing import TYPE_CHECKING

import signedjson.key
import signedjson.sign
from twisted.web.resource import Resource
from twisted.web.server import Request

from sydent.db.invite_tokens import JoinTokenStore
from sydent.http.auth import authV2
from sydent.http.servlets import MatrixRestError, get_args, jsonwrap, send_cors
from sydent.types import JsonDict

if TYPE_CHECKING:
from sydent.sydent import Sydent

logger = logging.getLogger(__name__)


class BlindlySignStuffServlet(Resource):
isLeaf = True

def __init__(self, syd, require_auth=False):
def __init__(self, syd: "Sydent", require_auth: bool = False) -> None:
self.sydent = syd
self.server_name = syd.server_name
self.tokenStore = JoinTokenStore(syd)
self.require_auth = require_auth

@jsonwrap
def render_POST(self, request):
def render_POST(self, request: Request) -> JsonDict:
send_cors(request)

if self.require_auth:
Expand Down Expand Up @@ -67,6 +73,6 @@ def render_POST(self, request):

return signed

def render_OPTIONS(self, request):
def render_OPTIONS(self, request: Request) -> bytes:
send_cors(request)
return b""
12 changes: 9 additions & 3 deletions sydent/http/servlets/bulklookupservlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,29 @@
# limitations under the License.

import logging
from typing import TYPE_CHECKING

from twisted.web.resource import Resource
from twisted.web.server import Request

from sydent.db.threepid_associations import GlobalAssociationStore
from sydent.http.servlets import MatrixRestError, get_args, jsonwrap, send_cors
from sydent.types import JsonDict

if TYPE_CHECKING:
from sydent.sydent import Sydent

logger = logging.getLogger(__name__)


class BulkLookupServlet(Resource):
isLeaf = True

def __init__(self, syd):
def __init__(self, syd: "Sydent") -> None:
self.sydent = syd

@jsonwrap
def render_POST(self, request):
def render_POST(self, request: Request) -> JsonDict:
"""
Bulk-lookup for threepids.
Params: 'threepids': list of threepids, each of which is a list of medium, address
Expand All @@ -53,6 +59,6 @@ def render_POST(self, request):

return {"threepids": results}

def render_OPTIONS(self, request):
def render_OPTIONS(self, request: Request) -> bytes:
send_cors(request)
return b""
24 changes: 16 additions & 8 deletions sydent/http/servlets/emailservlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from twisted.web.resource import Resource
from twisted.web.server import Request

from sydent.http.auth import authV2
from sydent.http.servlets import get_args, jsonwrap, send_cors
from sydent.types import JsonDict
from sydent.util.emailutils import EmailAddressException, EmailSendException
from sydent.util.stringutils import MAX_EMAIL_ADDRESS_LENGTH, is_valid_client_secret
from sydent.validators import (
Expand All @@ -25,16 +29,19 @@
SessionExpiredException,
)

if TYPE_CHECKING:
from sydent.sydent import Sydent


class EmailRequestCodeServlet(Resource):
isLeaf = True

def __init__(self, syd, require_auth=False):
def __init__(self, syd: "Sydent", require_auth: bool = False) -> None:
self.sydent = syd
self.require_auth = require_auth

@jsonwrap
def render_POST(self, request):
def render_POST(self, request: Request) -> JsonDict:
send_cors(request)

if self.require_auth:
Expand Down Expand Up @@ -83,19 +90,19 @@ def render_POST(self, request):

return resp

def render_OPTIONS(self, request):
def render_OPTIONS(self, request: Request) -> bytes:
send_cors(request)
return b""


class EmailValidateCodeServlet(Resource):
isLeaf = True

def __init__(self, syd, require_auth=False):
def __init__(self, syd: "Sydent", require_auth: bool = False) -> None:
self.sydent = syd
self.require_auth = require_auth

def render_GET(self, request):
def render_GET(self, request: Request) -> bytes:
args = get_args(request, ("nextLink",), required=False)

resp = None
Expand All @@ -122,18 +129,19 @@ def render_GET(self, request):

request.setHeader("Content-Type", "text/html")
res = open(templateFile).read() % {"message": msg}

return res.encode("UTF-8")

@jsonwrap
def render_POST(self, request):
def render_POST(self, request: Request) -> JsonDict:
send_cors(request)

if self.require_auth:
authV2(self.sydent, request)

return self.do_validate_request(request)

def do_validate_request(self, request):
def do_validate_request(self, request: Request) -> JsonDict:
"""
Extracts information about a validation session from the request and
attempts to validate that session.
Expand Down Expand Up @@ -188,6 +196,6 @@ def do_validate_request(self, request):
"error": "No session could be found with this sid",
}

def render_OPTIONS(self, request):
def render_OPTIONS(self, request: Request) -> bytes:
send_cors(request)
return b""
Loading