Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Convert the SimpleHttpClient to async. #8016

Merged
merged 1 commit into from
Aug 4, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/8016.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Convert various parts of the codebase to async/await.
2 changes: 1 addition & 1 deletion synapse/appservice/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def _get():
urllib.parse.quote(protocol),
)
try:
info = yield self.get_json(uri, {})
info = yield defer.ensureDeferred(self.get_json(uri, {}))

if not _is_valid_3pe_metadata(info):
logger.warning(
Expand Down
55 changes: 24 additions & 31 deletions synapse/http/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,8 +284,7 @@ def __getattr__(_self, attr):
ip_blacklist=self._ip_blacklist,
)

@defer.inlineCallbacks
def request(self, method, uri, data=None, headers=None):
async def request(self, method, uri, data=None, headers=None):
"""
Args:
method (str): HTTP method to use.
Expand Down Expand Up @@ -330,7 +329,7 @@ def request(self, method, uri, data=None, headers=None):
self.hs.get_reactor(),
cancelled_to_request_timed_out_error,
)
response = yield make_deferred_yieldable(request_deferred)
response = await make_deferred_yieldable(request_deferred)

incoming_responses_counter.labels(method, response.code).inc()
logger.info(
Expand All @@ -353,8 +352,7 @@ def request(self, method, uri, data=None, headers=None):
set_tag("error_reason", e.args[0])
raise

@defer.inlineCallbacks
def post_urlencoded_get_json(self, uri, args={}, headers=None):
async def post_urlencoded_get_json(self, uri, args={}, headers=None):
"""
Args:
uri (str):
Expand All @@ -363,7 +361,7 @@ def post_urlencoded_get_json(self, uri, args={}, headers=None):
header name to a list of values for that header

Returns:
Deferred[object]: parsed json
object: parsed json

Raises:
HttpResponseException: On a non-2xx HTTP response.
Expand All @@ -386,11 +384,11 @@ def post_urlencoded_get_json(self, uri, args={}, headers=None):
if headers:
actual_headers.update(headers)

response = yield self.request(
response = await self.request(
"POST", uri, headers=Headers(actual_headers), data=query_bytes
)

body = yield make_deferred_yieldable(readBody(response))
body = await make_deferred_yieldable(readBody(response))

if 200 <= response.code < 300:
return json.loads(body.decode("utf-8"))
Expand All @@ -399,8 +397,7 @@ def post_urlencoded_get_json(self, uri, args={}, headers=None):
response.code, response.phrase.decode("ascii", errors="replace"), body
)

@defer.inlineCallbacks
def post_json_get_json(self, uri, post_json, headers=None):
async def post_json_get_json(self, uri, post_json, headers=None):
"""

Args:
Expand All @@ -410,7 +407,7 @@ def post_json_get_json(self, uri, post_json, headers=None):
header name to a list of values for that header

Returns:
Deferred[object]: parsed json
object: parsed json

Raises:
HttpResponseException: On a non-2xx HTTP response.
Expand All @@ -429,11 +426,11 @@ def post_json_get_json(self, uri, post_json, headers=None):
if headers:
actual_headers.update(headers)

response = yield self.request(
response = await self.request(
"POST", uri, headers=Headers(actual_headers), data=json_str
)

body = yield make_deferred_yieldable(readBody(response))
body = await make_deferred_yieldable(readBody(response))

if 200 <= response.code < 300:
return json.loads(body.decode("utf-8"))
Expand All @@ -442,8 +439,7 @@ def post_json_get_json(self, uri, post_json, headers=None):
response.code, response.phrase.decode("ascii", errors="replace"), body
)

@defer.inlineCallbacks
def get_json(self, uri, args={}, headers=None):
async def get_json(self, uri, args={}, headers=None):
""" Gets some json from the given URI.

Args:
Expand All @@ -455,7 +451,7 @@ def get_json(self, uri, args={}, headers=None):
headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
header name to a list of values for that header
Returns:
Deferred: Succeeds when we get *any* 2xx HTTP response, with the
Succeeds when we get *any* 2xx HTTP response, with the
HTTP body as JSON.
Raises:
HttpResponseException On a non-2xx HTTP response.
Expand All @@ -466,11 +462,10 @@ def get_json(self, uri, args={}, headers=None):
if headers:
actual_headers.update(headers)

body = yield self.get_raw(uri, args, headers=headers)
body = await self.get_raw(uri, args, headers=headers)
return json.loads(body.decode("utf-8"))

@defer.inlineCallbacks
def put_json(self, uri, json_body, args={}, headers=None):
async def put_json(self, uri, json_body, args={}, headers=None):
""" Puts some json to the given URI.

Args:
Expand All @@ -483,7 +478,7 @@ def put_json(self, uri, json_body, args={}, headers=None):
headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
header name to a list of values for that header
Returns:
Deferred: Succeeds when we get *any* 2xx HTTP response, with the
Succeeds when we get *any* 2xx HTTP response, with the
HTTP body as JSON.
Raises:
HttpResponseException On a non-2xx HTTP response.
Expand All @@ -504,11 +499,11 @@ def put_json(self, uri, json_body, args={}, headers=None):
if headers:
actual_headers.update(headers)

response = yield self.request(
response = await self.request(
"PUT", uri, headers=Headers(actual_headers), data=json_str
)

body = yield make_deferred_yieldable(readBody(response))
body = await make_deferred_yieldable(readBody(response))

if 200 <= response.code < 300:
return json.loads(body.decode("utf-8"))
Expand All @@ -517,8 +512,7 @@ def put_json(self, uri, json_body, args={}, headers=None):
response.code, response.phrase.decode("ascii", errors="replace"), body
)

@defer.inlineCallbacks
def get_raw(self, uri, args={}, headers=None):
async def get_raw(self, uri, args={}, headers=None):
""" Gets raw text from the given URI.

Args:
Expand All @@ -530,7 +524,7 @@ def get_raw(self, uri, args={}, headers=None):
headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
header name to a list of values for that header
Returns:
Deferred: Succeeds when we get *any* 2xx HTTP response, with the
Succeeds when we get *any* 2xx HTTP response, with the
HTTP body as bytes.
Raises:
HttpResponseException on a non-2xx HTTP response.
Expand All @@ -543,9 +537,9 @@ def get_raw(self, uri, args={}, headers=None):
if headers:
actual_headers.update(headers)

response = yield self.request("GET", uri, headers=Headers(actual_headers))
response = await self.request("GET", uri, headers=Headers(actual_headers))

body = yield make_deferred_yieldable(readBody(response))
body = await make_deferred_yieldable(readBody(response))

if 200 <= response.code < 300:
return body
Expand All @@ -557,8 +551,7 @@ def get_raw(self, uri, args={}, headers=None):
# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
# The two should be factored out.

@defer.inlineCallbacks
def get_file(self, url, output_stream, max_size=None, headers=None):
async def get_file(self, url, output_stream, max_size=None, headers=None):
"""GETs a file from a given URL
Args:
url (str): The URL to GET
Expand All @@ -574,7 +567,7 @@ def get_file(self, url, output_stream, max_size=None, headers=None):
if headers:
actual_headers.update(headers)

response = yield self.request("GET", url, headers=Headers(actual_headers))
response = await self.request("GET", url, headers=Headers(actual_headers))

resp_headers = dict(response.headers.getAllRawHeaders())

Expand All @@ -598,7 +591,7 @@ def get_file(self, url, output_stream, max_size=None, headers=None):
# straight back in again

try:
length = yield make_deferred_yieldable(
length = await make_deferred_yieldable(
_readBodyToFile(response, output_stream, max_size)
)
except SynapseError:
Expand Down