Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add type hints to sydent/http/servlets #361

Merged
merged 15 commits into from
Jun 21, 2021
1 change: 1 addition & 0 deletions changelog.d/361.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add type hints to sydent/http/servlets/ to support mypy type checking.
20 changes: 13 additions & 7 deletions sydent/http/servlets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,17 @@
import functools
import json
import logging
from typing import TYPE_CHECKING, Any, Dict, Tuple

from twisted.internet import defer
from twisted.web import server

from sydent.util import json_decoder

if TYPE_CHECKING:
from twisted.python.failure import Failure
from twisted.web.server import Request
H-Shay marked this conversation as resolved.
Show resolved Hide resolved

logger = logging.getLogger(__name__)


Expand All @@ -38,7 +43,7 @@ def __init__(self, httpStatus, errcode, error):
self.error = error


def get_args(request, args, required=True):
def get_args(request: "Request", args: Tuple, required: bool = True) -> Dict[str, Any]:
H-Shay marked this conversation as resolved.
Show resolved Hide resolved
"""
Helper function to get arguments for an HTTP request.
Currently takes args from the top level keys of a json object or
Expand All @@ -62,6 +67,7 @@ def get_args(request, args, required=True):
are of type unicode.
:rtype: dict[unicode, any]
"""
assert request.path is not None
v1_path = request.path.startswith(b"/_matrix/identity/api/v1")

request_args = None
Expand Down Expand Up @@ -126,7 +132,7 @@ def get_args(request, args, required=True):

def jsonwrap(f):
@functools.wraps(f)
def inner(self, request, *args, **kwargs):
def inner(self, request: "Request", *args, **kwargs) -> bytes:
"""
Runs a web handler function with the given request and parameters, then
converts its result into JSON and returns it. If an error happens, also sets
Expand Down Expand Up @@ -162,7 +168,7 @@ def inner(self, request, *args, **kwargs):


def deferjsonwrap(f):
def reqDone(resp, request):
def reqDone(resp: Dict[str, Any], request: "Request") -> None:
"""
Converts the given response content into JSON and encodes it to bytes, then
writes it as the response to the given request with the right headers.
Expand All @@ -176,7 +182,7 @@ def reqDone(resp, request):
request.write(dict_to_json_bytes(resp))
request.finish()

def reqErr(failure, request):
def reqErr(failure: "Failure", request: "Request") -> None:
"""
Logs the given failure. If the failure is a MatrixRestError, writes a response
using the info it contains, otherwise responds with 500 Internal Server Error.
Expand Down Expand Up @@ -206,7 +212,7 @@ def reqErr(failure, request):
)
request.finish()

def inner(*args, **kwargs):
def inner(*args, **kwargs) -> int:
"""
Runs an asynchronous web handler function with the given arguments and add
reqDone and reqErr as the resulting Deferred's callbacks.
Expand All @@ -228,13 +234,13 @@ def inner(*args, **kwargs):
return inner


def send_cors(request):
def send_cors(request: "Request") -> None:
request.setHeader("Access-Control-Allow-Origin", "*")
request.setHeader("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
request.setHeader("Access-Control-Allow-Headers", "*")


def dict_to_json_bytes(content):
def dict_to_json_bytes(content: Dict[Any, Any]) -> bytes:
H-Shay marked this conversation as resolved.
Show resolved Hide resolved
"""
Converts a dict into JSON and encodes it to bytes.

Expand Down
13 changes: 10 additions & 3 deletions sydent/http/servlets/accountservlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,28 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Dict

from twisted.web.resource import Resource

from sydent.http.auth import authV2
from sydent.http.servlets import jsonwrap, send_cors

if TYPE_CHECKING:
from twisted.web.server import Request

from sydent.sydent import Sydent


class AccountServlet(Resource):
isLeaf = False

def __init__(self, syd):
def __init__(self, syd: "Sydent") -> None:
Resource.__init__(self)
self.sydent = syd

@jsonwrap
def render_GET(self, request):
def render_GET(self, request: "Request") -> Dict[str, str]:
"""
Return information about the user's account
(essentially just a 'who am i')
Expand All @@ -39,6 +46,6 @@ def render_GET(self, request):
"user_id": account.userId,
}

def render_OPTIONS(self, request):
def render_OPTIONS(self, request: "Request") -> bytes:
send_cors(request)
return b""
13 changes: 10 additions & 3 deletions sydent/http/servlets/authenticated_bind_threepid_servlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,30 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Any, Dict

from twisted.web.resource import Resource

from sydent.http.servlets import get_args, jsonwrap, send_cors

if TYPE_CHECKING:
from twisted.web.server import Request

from sydent.sydent import Sydent


class AuthenticatedBindThreePidServlet(Resource):
"""A servlet which allows a caller to bind any 3pid they want to an mxid

It is assumed that authentication happens out of band
"""

def __init__(self, sydent):
def __init__(self, sydent: "Sydent") -> None:
Resource.__init__(self)
self.sydent = sydent

@jsonwrap
def render_POST(self, request):
def render_POST(self, request: "Request") -> Dict[str, Any]:
send_cors(request)
args = get_args(request, ("medium", "address", "mxid"))

Expand All @@ -38,6 +45,6 @@ def render_POST(self, request):
args["mxid"],
)

def render_OPTIONS(self, request):
def render_OPTIONS(self, request: "Request") -> bytes:
send_cors(request)
return b""
13 changes: 10 additions & 3 deletions sydent/http/servlets/authenticated_unbind_threepid_servlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,30 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from twisted.web.resource import Resource

from sydent.http.servlets import get_args, jsonwrap, send_cors

if TYPE_CHECKING:
from twisted.web.server import Request

from sydent.sydent import Sydent


class AuthenticatedUnbindThreePidServlet(Resource):
"""A servlet which allows a caller to unbind any 3pid they want from an mxid

It is assumed that authentication happens out of band
"""

def __init__(self, sydent):
def __init__(self, sydent: "Sydent") -> None:
Resource.__init__(self)
self.sydent = sydent

@jsonwrap
def render_POST(self, request):
def render_POST(self, request: "Request") -> None:
send_cors(request)
args = get_args(request, ("medium", "address", "mxid"))

Expand All @@ -39,6 +46,6 @@ def render_POST(self, request):
args["mxid"],
)

def render_OPTIONS(self, request):
def render_OPTIONS(self, request: "Request") -> bytes:
send_cors(request)
return b""
12 changes: 9 additions & 3 deletions sydent/http/servlets/blindlysignstuffservlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import logging
from typing import TYPE_CHECKING

import signedjson.key
import signedjson.sign
Expand All @@ -22,20 +23,25 @@
from sydent.http.auth import authV2
from sydent.http.servlets import MatrixRestError, get_args, jsonwrap, send_cors

if TYPE_CHECKING:
from twisted.web.server import Request

from sydent.sydent import Sydent

logger = logging.getLogger(__name__)


class BlindlySignStuffServlet(Resource):
isLeaf = True

def __init__(self, syd, require_auth=False):
def __init__(self, syd: "Sydent", require_auth: bool = False) -> None:
self.sydent = syd
self.server_name = syd.server_name
self.tokenStore = JoinTokenStore(syd)
self.require_auth = require_auth

@jsonwrap
def render_POST(self, request):
def render_POST(self, request: "Request"):
send_cors(request)

if self.require_auth:
Expand Down Expand Up @@ -67,6 +73,6 @@ def render_POST(self, request):

return signed

def render_OPTIONS(self, request):
def render_OPTIONS(self, request: "Request") -> bytes:
send_cors(request)
return b""
12 changes: 9 additions & 3 deletions sydent/http/servlets/bulklookupservlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,29 @@
# limitations under the License.

import logging
from typing import TYPE_CHECKING, Dict, List, Tuple

from twisted.web.resource import Resource

from sydent.db.threepid_associations import GlobalAssociationStore
from sydent.http.servlets import MatrixRestError, get_args, jsonwrap, send_cors

if TYPE_CHECKING:
from twisted.web.server import Request

from sydent.sydent import Sydent

logger = logging.getLogger(__name__)


class BulkLookupServlet(Resource):
isLeaf = True

def __init__(self, syd):
def __init__(self, syd: "Sydent") -> None:
self.sydent = syd

@jsonwrap
def render_POST(self, request):
def render_POST(self, request: "Request") -> Dict[str, List[Tuple[str, str, str]]]:
"""
Bulk-lookup for threepids.
Params: 'threepids': list of threepids, each of which is a list of medium, address
Expand All @@ -53,6 +59,6 @@ def render_POST(self, request):

return {"threepids": results}

def render_OPTIONS(self, request):
def render_OPTIONS(self, request: "Request") -> bytes:
send_cors(request)
return b""
24 changes: 16 additions & 8 deletions sydent/http/servlets/emailservlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Dict, Union

from twisted.web.resource import Resource

from sydent.http.auth import authV2
Expand All @@ -25,16 +27,21 @@
SessionExpiredException,
)

if TYPE_CHECKING:
from twisted.web.server import Request

from sydent.sydent import Sydent


class EmailRequestCodeServlet(Resource):
isLeaf = True

def __init__(self, syd, require_auth=False):
def __init__(self, syd: "Sydent", require_auth: bool = False) -> None:
self.sydent = syd
self.require_auth = require_auth

@jsonwrap
def render_POST(self, request):
def render_POST(self, request: "Request") -> Dict:
send_cors(request)

if self.require_auth:
Expand Down Expand Up @@ -83,19 +90,19 @@ def render_POST(self, request):

return resp

def render_OPTIONS(self, request):
def render_OPTIONS(self, request: "Request") -> bytes:
send_cors(request)
return b""


class EmailValidateCodeServlet(Resource):
isLeaf = True

def __init__(self, syd, require_auth=False):
def __init__(self, syd: "Sydent", require_auth: bool = False) -> None:
self.sydent = syd
self.require_auth = require_auth

def render_GET(self, request):
def render_GET(self, request: "Request") -> bytes:
args = get_args(request, ("nextLink",), required=False)

resp = None
Expand All @@ -122,18 +129,19 @@ def render_GET(self, request):

request.setHeader("Content-Type", "text/html")
res = open(templateFile).read() % {"message": msg}

return res.encode("UTF-8")

@jsonwrap
def render_POST(self, request):
def render_POST(self, request: "Request") -> Dict[str, Union[bool, str]]:
send_cors(request)

if self.require_auth:
authV2(self.sydent, request)

return self.do_validate_request(request)

def do_validate_request(self, request):
def do_validate_request(self, request: "Request") -> Dict[str, Union[bool, str]]:
"""
Extracts information about a validation session from the request and
attempts to validate that session.
Expand Down Expand Up @@ -188,6 +196,6 @@ def do_validate_request(self, request):
"error": "No session could be found with this sid",
}

def render_OPTIONS(self, request):
def render_OPTIONS(self, request: "Request") -> bytes:
send_cors(request)
return b""
Loading