Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Pass a Site into make_request #8757

Merged
merged 7 commits into from
Nov 16, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/8757.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Refactor test utilities for injecting HTTP requests.
13 changes: 7 additions & 6 deletions tests/app/test_frontend_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from synapse.app.generic_worker import GenericWorkerServer

from tests.server import make_request, render
from tests.unittest import HomeserverTestCase


Expand Down Expand Up @@ -55,10 +56,10 @@ def test_listen_http_with_presence_enabled(self):
# Grab the resource from the site that was told to listen
self.assertEqual(len(self.reactor.tcpServers), 1)
site = self.reactor.tcpServers[0][1]
self.resource = site.resource.children[b"_matrix"].children[b"client"]
resource = site.resource.children[b"_matrix"].children[b"client"]

request, channel = self.make_request("PUT", "presence/a/status")
self.render(request)
request, channel = make_request(self.reactor, site, "PUT", "presence/a/status")
render(request, resource, self.reactor)

# 400 + unrecognised, because nothing is registered
self.assertEqual(channel.code, 400)
Expand All @@ -77,10 +78,10 @@ def test_listen_http_with_presence_disabled(self):
# Grab the resource from the site that was told to listen
self.assertEqual(len(self.reactor.tcpServers), 1)
site = self.reactor.tcpServers[0][1]
self.resource = site.resource.children[b"_matrix"].children[b"client"]
resource = site.resource.children[b"_matrix"].children[b"client"]

request, channel = self.make_request("PUT", "presence/a/status")
self.render(request)
request, channel = make_request(self.reactor, site, "PUT", "presence/a/status")
render(request, resource, self.reactor)

# 401, because the stub servlet still checks authentication
self.assertEqual(channel.code, 401)
Expand Down
17 changes: 9 additions & 8 deletions tests/app/test_openid_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from synapse.app.homeserver import SynapseHomeServer
from synapse.config.server import parse_listener_def

from tests.server import make_request, render
from tests.unittest import HomeserverTestCase


Expand Down Expand Up @@ -66,16 +67,16 @@ def test_openid_listener(self, names, expectation):
# Grab the resource from the site that was told to listen
site = self.reactor.tcpServers[0][1]
try:
self.resource = site.resource.children[b"_matrix"].children[b"federation"]
resource = site.resource.children[b"_matrix"].children[b"federation"]
except KeyError:
if expectation == "no_resource":
return
raise

request, channel = self.make_request(
"GET", "/_matrix/federation/v1/openid/userinfo"
request, channel = make_request(
self.reactor, site, "GET", "/_matrix/federation/v1/openid/userinfo"
)
self.render(request)
render(request, resource, self.reactor)

self.assertEqual(channel.code, 401)

Expand Down Expand Up @@ -115,15 +116,15 @@ def test_openid_listener(self, names, expectation):
# Grab the resource from the site that was told to listen
site = self.reactor.tcpServers[0][1]
try:
self.resource = site.resource.children[b"_matrix"].children[b"federation"]
resource = site.resource.children[b"_matrix"].children[b"federation"]
except KeyError:
if expectation == "no_resource":
return
raise

request, channel = self.make_request(
"GET", "/_matrix/federation/v1/openid/userinfo"
request, channel = make_request(
self.reactor, site, "GET", "/_matrix/federation/v1/openid/userinfo"
)
self.render(request)
render(request, resource, self.reactor)

self.assertEqual(channel.code, 401)
13 changes: 7 additions & 6 deletions tests/http/test_additional_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from synapse.http.additional_resource import AdditionalResource
from synapse.http.server import respond_with_json

from tests.server import FakeSite, make_request, render
from tests.unittest import HomeserverTestCase


Expand All @@ -43,20 +44,20 @@ class AdditionalResourceTests(HomeserverTestCase):

def test_async(self):
handler = _AsyncTestCustomEndpoint({}, None).handle_request
self.resource = AdditionalResource(self.hs, handler)
resource = AdditionalResource(self.hs, handler)

request, channel = self.make_request("GET", "/")
self.render(request)
request, channel = make_request(self.reactor, FakeSite(resource), "GET", "/")
render(request, resource, self.reactor)

self.assertEqual(request.code, 200)
self.assertEqual(channel.json_body, {"some_key": "some_value_async"})

def test_sync(self):
handler = _SyncTestCustomEndpoint({}, None).handle_request
self.resource = AdditionalResource(self.hs, handler)
resource = AdditionalResource(self.hs, handler)

request, channel = self.make_request("GET", "/")
self.render(request)
request, channel = make_request(self.reactor, FakeSite(resource), "GET", "/")
render(request, resource, self.reactor)

self.assertEqual(request.code, 200)
self.assertEqual(channel.json_body, {"some_key": "some_value_sync"})
29 changes: 22 additions & 7 deletions tests/replication/test_client_reader_shard.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.rest.client.v2_alpha.test_auth import DummyRecaptchaChecker
from tests.server import FakeChannel
from tests.server import FakeChannel, make_request

logger = logging.getLogger(__name__)

Expand All @@ -46,8 +46,11 @@ def test_register_single_worker(self):
"""Test that registration works when using a single client reader worker.
"""
worker_hs = self.make_worker_hs("synapse.app.client_reader")
site = self._hs_to_site[worker_hs]

request_1, channel_1 = self.make_request(
request_1, channel_1 = make_request(
self.reactor,
site,
"POST",
"register",
{"username": "user", "type": "m.login.password", "password": "bar"},
Expand All @@ -59,8 +62,12 @@ def test_register_single_worker(self):
session = channel_1.json_body["session"]

# also complete the dummy auth
request_2, channel_2 = self.make_request(
"POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}}
request_2, channel_2 = make_request(
self.reactor,
site,
"POST",
"register",
{"auth": {"session": session, "type": "m.login.dummy"}},
) # type: SynapseRequest, FakeChannel
self.render_on_worker(worker_hs, request_2)
self.assertEqual(request_2.code, 200)
Expand All @@ -74,7 +81,10 @@ def test_register_multi_worker(self):
worker_hs_1 = self.make_worker_hs("synapse.app.client_reader")
worker_hs_2 = self.make_worker_hs("synapse.app.client_reader")

request_1, channel_1 = self.make_request(
site_1 = self._hs_to_site[worker_hs_1]
request_1, channel_1 = make_request(
self.reactor,
site_1,
"POST",
"register",
{"username": "user", "type": "m.login.password", "password": "bar"},
Expand All @@ -86,8 +96,13 @@ def test_register_multi_worker(self):
session = channel_1.json_body["session"]

# also complete the dummy auth
request_2, channel_2 = self.make_request(
"POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}}
site_2 = self._hs_to_site[worker_hs_2]
request_2, channel_2 = make_request(
self.reactor,
site_2,
"POST",
"register",
{"auth": {"session": session, "type": "m.login.dummy"}},
) # type: SynapseRequest, FakeChannel
self.render_on_worker(worker_hs_2, request_2)
self.assertEqual(request_2.code, 200)
Expand Down
10 changes: 6 additions & 4 deletions tests/replication/test_multi_media_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file
from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.server import FakeChannel, FakeTransport
from tests.server import FakeChannel, FakeSite, FakeTransport, make_request

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -67,14 +67,16 @@ def _get_media_req(
The channel for the *client* request and the *outbound* request for
the media which the caller should respond to.
"""

request, channel = self.make_request(
resource = hs.get_media_repository_resource().children[b"download"]
request, channel = make_request(
self.reactor,
FakeSite(resource),
"GET",
"/{}/{}".format(target, media_id),
shorthand=False,
access_token=self.access_token,
)
request.render(hs.get_media_repository_resource().children[b"download"])
request.render(resource)
self.pump()

clients = self.reactor.tcpClients
Expand Down
42 changes: 32 additions & 10 deletions tests/replication/test_sharded_event_persister.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from synapse.rest.client.v2_alpha import sync

from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.server import make_request
from tests.utils import USE_POSTGRES_FOR_TESTS

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -148,6 +149,7 @@ def test_vector_clock_token(self):
sync_hs = self.make_worker_hs(
"synapse.app.generic_worker", {"worker_name": "sync"},
)
sync_hs_site = self._hs_to_site[sync_hs]

# Specially selected room IDs that get persisted on different workers.
room_id1 = "!foo:test"
Expand Down Expand Up @@ -178,7 +180,9 @@ def test_vector_clock_token(self):
)

# Do an initial sync so that we're up to date.
request, channel = self.make_request("GET", "/sync", access_token=access_token)
request, channel = make_request(
self.reactor, sync_hs_site, "GET", "/sync", access_token=access_token
)
self.render_on_worker(sync_hs, request)
next_batch = channel.json_body["next_batch"]

Expand All @@ -203,8 +207,12 @@ def test_vector_clock_token(self):

# Check that syncing still gets the new event, despite the gap in the
# stream IDs.
request, channel = self.make_request(
"GET", "/sync?since={}".format(next_batch), access_token=access_token
request, channel = make_request(
self.reactor,
sync_hs_site,
"GET",
"/sync?since={}".format(next_batch),
access_token=access_token,
)
self.render_on_worker(sync_hs, request)

Expand All @@ -230,7 +238,9 @@ def test_vector_clock_token(self):
response = self.helper.send(room_id2, body="Hi!", tok=self.other_access_token)
first_event_in_room2 = response["event_id"]

request, channel = self.make_request(
request, channel = make_request(
self.reactor,
sync_hs_site,
"GET",
"/sync?since={}".format(vector_clock_token),
access_token=access_token,
Expand All @@ -254,8 +264,12 @@ def test_vector_clock_token(self):
self.helper.send(room_id1, body="Hi again!", tok=self.other_access_token)
self.helper.send(room_id2, body="Hi again!", tok=self.other_access_token)

request, channel = self.make_request(
"GET", "/sync?since={}".format(next_batch), access_token=access_token
request, channel = make_request(
self.reactor,
sync_hs_site,
"GET",
"/sync?since={}".format(next_batch),
access_token=access_token,
)
self.render_on_worker(sync_hs, request)

Expand All @@ -269,7 +283,9 @@ def test_vector_clock_token(self):
# Paginating back in the first room should not produce any results, as
# no events have happened in it. This tests that we are correctly
# filtering results based on the vector clock portion.
request, channel = self.make_request(
request, channel = make_request(
self.reactor,
sync_hs_site,
"GET",
"/rooms/{}/messages?from={}&to={}&dir=b".format(
room_id1, prev_batch1, vector_clock_token
Expand All @@ -281,7 +297,9 @@ def test_vector_clock_token(self):

# Paginating back on the second room should produce the first event
# again. This tests that pagination isn't completely broken.
request, channel = self.make_request(
request, channel = make_request(
self.reactor,
sync_hs_site,
"GET",
"/rooms/{}/messages?from={}&to={}&dir=b".format(
room_id2, prev_batch2, vector_clock_token
Expand All @@ -295,7 +313,9 @@ def test_vector_clock_token(self):
)

# Paginating forwards should give the same results
request, channel = self.make_request(
request, channel = make_request(
self.reactor,
sync_hs_site,
"GET",
"/rooms/{}/messages?from={}&to={}&dir=f".format(
room_id1, vector_clock_token, prev_batch1
Expand All @@ -305,7 +325,9 @@ def test_vector_clock_token(self):
self.render_on_worker(sync_hs, request)
self.assertListEqual([], channel.json_body["chunk"])

request, channel = self.make_request(
request, channel = make_request(
self.reactor,
sync_hs_site,
"GET",
"/rooms/{}/messages?from={}&to={}&dir=f".format(
room_id2, vector_clock_token, prev_batch2,
Expand Down
18 changes: 14 additions & 4 deletions tests/rest/admin/test_admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from synapse.rest.client.v2_alpha import groups

from tests import unittest
from tests.server import FakeSite, make_request


class VersionTestCase(unittest.HomeserverTestCase):
Expand Down Expand Up @@ -222,8 +223,13 @@ def write_to(r):

def _ensure_quarantined(self, admin_user_tok, server_and_media_id):
"""Ensure a piece of media is quarantined when trying to access it."""
request, channel = self.make_request(
"GET", server_and_media_id, shorthand=False, access_token=admin_user_tok,
request, channel = make_request(
self.reactor,
FakeSite(self.download_resource),
"GET",
server_and_media_id,
shorthand=False,
access_token=admin_user_tok,
)
request.render(self.download_resource)
self.pump(1.0)
Expand Down Expand Up @@ -287,7 +293,9 @@ def test_quarantine_media_by_id(self):
server_name, media_id = server_name_and_media_id.split("/")

# Attempt to access the media
request, channel = self.make_request(
request, channel = make_request(
self.reactor,
FakeSite(self.download_resource),
"GET",
server_name_and_media_id,
shorthand=False,
Expand Down Expand Up @@ -462,7 +470,9 @@ def test_cannot_quarantine_safe_media(self):
self._ensure_quarantined(admin_user_tok, server_and_media_id_1)

# Attempt to access each piece of media
request, channel = self.make_request(
request, channel = make_request(
self.reactor,
FakeSite(self.download_resource),
"GET",
server_and_media_id_2,
shorthand=False,
Expand Down
Loading