Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Fixup synapse.rest to pass mypy #6732

Merged
merged 7 commits into from
Jan 20, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/6730.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix changing password via user admin API.
1 change: 1 addition & 0 deletions changelog.d/6731.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix `/events/:event_id` deprecated API.
1 change: 1 addition & 0 deletions changelog.d/6732.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixup `synapse.rest` to pass mypy.
9 changes: 9 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,12 @@ ignore_missing_imports = True

[mypy-sentry_sdk]
ignore_missing_imports = True

[mypy-PIL.*]
ignore_missing_imports = True

[mypy-lxml]
ignore_missing_imports = True

[mypy-jwt.*]
ignore_missing_imports = True
27 changes: 14 additions & 13 deletions synapse/rest/admin/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,8 @@ async def on_PUT(self, request, user_id):
raise SynapseError(400, "Invalid password")
else:
new_password = body["password"]
await self._set_password_handler.set_password(
target_user, new_password, requester
await self.set_password_handler.set_password(
target_user.to_string(), new_password, requester
)

if "deactivated" in body:
Expand Down Expand Up @@ -338,21 +338,22 @@ async def on_POST(self, request):

got_mac = body["mac"]

want_mac = hmac.new(
want_mac_builder = hmac.new(
key=self.hs.config.registration_shared_secret.encode(),
digestmod=hashlib.sha1,
)
want_mac.update(nonce.encode("utf8"))
want_mac.update(b"\x00")
want_mac.update(username)
want_mac.update(b"\x00")
want_mac.update(password)
want_mac.update(b"\x00")
want_mac.update(b"admin" if admin else b"notadmin")
want_mac_builder.update(nonce.encode("utf8"))
want_mac_builder.update(b"\x00")
want_mac_builder.update(username)
want_mac_builder.update(b"\x00")
want_mac_builder.update(password)
want_mac_builder.update(b"\x00")
want_mac_builder.update(b"admin" if admin else b"notadmin")
if user_type:
want_mac.update(b"\x00")
want_mac.update(user_type.encode("utf8"))
want_mac = want_mac.hexdigest()
want_mac_builder.update(b"\x00")
want_mac_builder.update(user_type.encode("utf8"))

want_mac = want_mac_builder.hexdigest()

if not hmac.compare_digest(want_mac.encode("ascii"), got_mac.encode("ascii")):
raise SynapseError(403, "HMAC incorrect")
Expand Down
1 change: 1 addition & 0 deletions synapse/rest/client/v1/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def __init__(self, hs):
super(EventRestServlet, self).__init__()
self.clock = hs.get_clock()
self.event_handler = hs.get_event_handler()
self.auth = hs.get_auth()
self._event_serializer = hs.get_event_client_serializer()

async def on_GET(self, request, event_id):
Expand Down
2 changes: 1 addition & 1 deletion synapse/rest/client/v1/login.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,7 @@ def parse_cas_response(self, cas_response_body):
if user is None:
raise Exception("CAS response does not contain user")
except Exception:
logger.error("Error parsing CAS response", exc_info=1)
logger.exception("Error parsing CAS response")
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
if not success:
raise LoginError(
Expand Down
18 changes: 12 additions & 6 deletions synapse/rest/client/v1/room.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

""" This module contains REST servlets to do with rooms: /rooms/<paths> """
import logging
from typing import List, Optional

from six.moves.urllib import parse as urlparse

Expand Down Expand Up @@ -207,7 +208,7 @@ async def on_PUT(self, request, room_id, event_type, state_key, txn_id=None):
requester, event_dict, txn_id=txn_id
)

ret = {}
ret = {} # type: dict
if event:
set_tag("event_id", event.event_id)
ret = {"event_id": event.event_id}
Expand Down Expand Up @@ -285,7 +286,7 @@ async def on_POST(self, request, room_identifier, txn_id=None):
try:
remote_room_hosts = [
x.decode("ascii") for x in request.args[b"server_name"]
]
] # type: Optional[List[str]]
except Exception:
remote_room_hosts = None
elif RoomAlias.is_valid(room_identifier):
Expand Down Expand Up @@ -375,7 +376,7 @@ async def on_POST(self, request):
server = parse_string(request, "server", default=None)
content = parse_json_object_from_request(request)

limit = int(content.get("limit", 100))
limit = int(content.get("limit", 100)) # type: Optional[int]
since_token = content.get("since", None)
search_filter = content.get("filter", None)

Expand Down Expand Up @@ -504,11 +505,16 @@ async def on_GET(self, request, room_id):
filter_bytes = parse_string(request, b"filter", encoding=None)
if filter_bytes:
filter_json = urlparse.unquote(filter_bytes.decode("UTF-8"))
event_filter = Filter(json.loads(filter_json))
if event_filter.filter_json.get("event_format", "client") == "federation":
event_filter = Filter(json.loads(filter_json)) # type: Optional[Filter]
if (
event_filter
and event_filter.filter_json.get("event_format", "client")
== "federation"
):
as_client_event = False
else:
event_filter = None

msgs = await self.pagination_handler.get_messages(
room_id=room_id,
requester=requester,
Expand Down Expand Up @@ -611,7 +617,7 @@ async def on_GET(self, request, room_id, event_id):
filter_bytes = parse_string(request, "filter")
if filter_bytes:
filter_json = urlparse.unquote(filter_bytes)
event_filter = Filter(json.loads(filter_json))
event_filter = Filter(json.loads(filter_json)) # type: Optional[Filter]
else:
event_filter = None

Expand Down
3 changes: 2 additions & 1 deletion synapse/rest/client/v2_alpha/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from six import string_types

import synapse
import synapse.api.auth
import synapse.types
from synapse.api.constants import LoginType
from synapse.api.errors import (
Expand Down Expand Up @@ -405,7 +406,7 @@ async def on_POST(self, request):
return ret
elif kind != b"user":
raise UnrecognizedRequestError(
"Do not understand membership kind: %s" % (kind,)
"Do not understand membership kind: %s" % (kind.decode("utf8"),)
)

# we do basic sanity checks here because the auth layer will store these
Expand Down
3 changes: 2 additions & 1 deletion synapse/rest/client/v2_alpha/sendtodevice.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.

import logging
from typing import Tuple

from synapse.http import servlet
from synapse.http.servlet import parse_json_object_from_request
Expand Down Expand Up @@ -60,7 +61,7 @@ async def _put(self, request, message_type, txn_id):
sender_user_id, message_type, content["messages"]
)

response = (200, {})
response = (200, {}) # type: Tuple[int, dict]
return response


Expand Down
5 changes: 3 additions & 2 deletions synapse/rest/key/v2/remote_key_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import logging
from typing import Dict, Set

from canonicaljson import encode_canonical_json, json
from signedjson.sign import sign_json
Expand Down Expand Up @@ -103,7 +104,7 @@ def __init__(self, hs):
async def _async_render_GET(self, request):
if len(request.postpath) == 1:
(server,) = request.postpath
query = {server.decode("ascii"): {}}
query = {server.decode("ascii"): {}} # type: dict
elif len(request.postpath) == 2:
server, key_id = request.postpath
minimum_valid_until_ts = parse_integer(request, "minimum_valid_until_ts")
Expand Down Expand Up @@ -148,7 +149,7 @@ def query_keys(self, request, query, query_remote_on_cache_miss=False):

time_now_ms = self.clock.time_msec()

cache_misses = dict()
cache_misses = dict() # type: Dict[str, Set[str]]
for (server_name, key_id, from_server), results in cached.items():
results = [(result["ts_added_ms"], result) for result in results]

Expand Down
3 changes: 2 additions & 1 deletion synapse/rest/media/v1/media_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import logging
import os
import shutil
from typing import Dict, Tuple

from six import iteritems

Expand Down Expand Up @@ -605,7 +606,7 @@ def _generate_thumbnails(

# We deduplicate the thumbnail sizes by ignoring the cropped versions if
# they have the same dimensions of a scaled one.
thumbnails = {}
thumbnails = {} # type: Dict[Tuple[int, int, str], str]
for r_width, r_height, r_method, r_type in requirements:
if r_method == "crop":
thumbnails.setdefault((r_width, r_height, r_type), r_method)
Expand Down
7 changes: 4 additions & 3 deletions synapse/rest/media/v1/preview_url_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import shutil
import sys
import traceback
from typing import Dict, Optional

import six
from six import string_types
Expand Down Expand Up @@ -237,8 +238,8 @@ def _do_preview(self, url, user, ts):
# If we don't find a match, we'll look at the HTTP Content-Type, and
# if that doesn't exist, we'll fall back to UTF-8.
if not encoding:
match = _content_type_match.match(media_info["media_type"])
encoding = match.group(1) if match else "utf-8"
content_match = _content_type_match.match(media_info["media_type"])
encoding = content_match.group(1) if content_match else "utf-8"

og = decode_and_calc_og(body, media_info["uri"], encoding)

Expand Down Expand Up @@ -518,7 +519,7 @@ def _calc_og(tree, media_uri):
# "og:video:height" : "720",
# "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3",

og = {}
og = {} # type: Dict[str, Optional[str]]
for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"):
if "content" in tag.attrib:
# if we've got more than 50 tags, someone is taking the piss
Expand Down
14 changes: 7 additions & 7 deletions synapse/rest/media/v1/thumbnail_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,8 +296,8 @@ def _select_thumbnail(
d_h = desired_height

if desired_method.lower() == "crop":
info_list = []
info_list2 = []
crop_info_list = []
crop_info_list2 = []
for info in thumbnail_infos:
t_w = info["thumbnail_width"]
t_h = info["thumbnail_height"]
Expand All @@ -309,7 +309,7 @@ def _select_thumbnail(
type_quality = desired_type != info["thumbnail_type"]
length_quality = info["thumbnail_length"]
if t_w >= d_w or t_h >= d_h:
info_list.append(
crop_info_list.append(
(
aspect_quality,
min_quality,
Expand All @@ -320,7 +320,7 @@ def _select_thumbnail(
)
)
else:
info_list2.append(
crop_info_list2.append(
(
aspect_quality,
min_quality,
Expand All @@ -330,10 +330,10 @@ def _select_thumbnail(
info,
)
)
if info_list:
return min(info_list)[-1]
if crop_info_list:
return min(crop_info_list2)[-1]
else:
return min(info_list2)[-1]
return min(crop_info_list2)[-1]
else:
info_list = []
info_list2 = []
Expand Down
13 changes: 13 additions & 0 deletions tests/rest/admin/test_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,19 @@ def test_requester_is_admin(self):
self.assertEqual(0, channel.json_body["is_guest"])
self.assertEqual(0, channel.json_body["deactivated"])

# Change password
body = json.dumps({"password": "hahaha"})

request, channel = self.make_request(
"PUT",
self.url,
access_token=self.admin_user_tok,
content=body.encode(encoding="utf_8"),
)
self.render(request)

self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])

# Modify user
body = json.dumps({"displayname": "foobar", "deactivated": True})

Expand Down
27 changes: 27 additions & 0 deletions tests/rest/client/v1/test_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,30 @@ def TODO_test_stream_items(self):

# someone else set topic, expect 6 (join,send,topic,join,send,topic)
pass


class GetEventsTestCase(unittest.HomeserverTestCase):
servlets = [
events.register_servlets,
room.register_servlets,
synapse.rest.admin.register_servlets_for_client_rest_resource,
login.register_servlets,
]

def prepare(self, hs, reactor, clock):

# register an account
self.user_id = self.register_user("sid1", "pass")
self.token = self.login(self.user_id, "pass")

self.room_id = self.helper.create_room_as(self.user_id, tok=self.token)

def test_get_event_via_events(self):
resp = self.helper.send(self.room_id, tok=self.token)
event_id = resp["event_id"]

request, channel = self.make_request(
"GET", "/events/" + event_id, access_token=self.token,
)
self.render(request)
self.assertEquals(channel.code, 200, msg=channel.result)
2 changes: 1 addition & 1 deletion tests/unittest.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ def register_user(self, username, password, admin=False):
# Create the user
request, channel = self.make_request("GET", "/_matrix/client/r0/admin/register")
self.render(request)
self.assertEqual(channel.code, 200)
self.assertEqual(channel.code, 200, msg=channel.result)
nonce = channel.json_body["nonce"]

want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
Expand Down
3 changes: 1 addition & 2 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,7 @@ commands = mypy \
synapse/logging/ \
synapse/module_api \
synapse/replication \
synapse/rest/consent \
synapse/rest/saml2 \
synapse/rest \
synapse/spam_checker_api \
synapse/storage/engines \
synapse/streams
Expand Down