Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Request & follow redirects for /media/v3/download #16701

Merged
merged 5 commits into from
Nov 29, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions synapse/federation/federation_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
TYPE_CHECKING,
AbstractSet,
Awaitable,
BinaryIO,
Callable,
Collection,
Container,
Expand Down Expand Up @@ -1862,6 +1863,41 @@ def filter_user_id(user_id: str) -> bool:

return filtered_statuses, filtered_failures

async def download_media(
    self,
    destination: str,
    media_id: str,
    output_stream: BinaryIO,
    max_size: int,
    max_timeout_ms: int,
) -> Tuple[int, Dict[bytes, List[bytes]]]:
    """Download remote media, preferring the v3 endpoint over the r0 one.

    The v3 download is attempted first through the transport layer. If that
    fails with an `HttpResponseException` which `is_unknown_endpoint`
    reports as an unrecognised endpoint, the same request is retried
    against the r0 endpoint.

    Args:
        destination: The remote server to request the media from.
        media_id: The ID of the media to download.
        output_stream: Stream handed to the transport layer to receive the
            downloaded file.
        max_size: Size limit, passed through unchanged to the transport
            layer.
        max_timeout_ms: Timeout in milliseconds, passed through unchanged
            to the transport layer.

    Returns:
        The transport layer's result unchanged; callers unpack it as
        `(length, headers)`.

    Raises:
        HttpResponseException: If the v3 request fails with an error that
            is not an unrecognised-endpoint error. Errors raised by the r0
            fallback propagate unchanged.
    """
    try:
        return await self.transport_layer.download_media_v3(
            destination,
            media_id,
            output_stream=output_stream,
            max_size=max_size,
            max_timeout_ms=max_timeout_ms,
        )
    except HttpResponseException as e:
        # If an error is received that is due to an unrecognised endpoint,
        # fall back to the r0 endpoint. Otherwise, consider it a legitimate
        # error and raise.
        if not is_unknown_endpoint(e):
            raise

        logger.debug(
            "Couldn't download media with the v3 API, falling back to the r0 API"
        )

    # Deliberately outside the `except` block so that a failure of the r0
    # request is not reported as having occurred while handling the v3 error.
    return await self.transport_layer.download_media_r0(
        destination,
        media_id,
        output_stream=output_stream,
        max_size=max_size,
        max_timeout_ms=max_timeout_ms,
    )


@attr.s(frozen=True, slots=True, auto_attribs=True)
class TimestampToEventResponse:
Expand Down
49 changes: 49 additions & 0 deletions synapse/federation/transport/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from typing import (
TYPE_CHECKING,
Any,
BinaryIO,
Callable,
Collection,
Dict,
Expand Down Expand Up @@ -804,6 +805,54 @@ async def get_account_status(
destination=destination, path=path, data={"user_ids": user_ids}
)

async def download_media_r0(
    self,
    destination: str,
    media_id: str,
    output_stream: BinaryIO,
    max_size: int,
    max_timeout_ms: int,
) -> Tuple[int, Dict[bytes, List[bytes]]]:
    """Request media from a remote server using the r0 download endpoint.

    Builds the path `/_matrix/media/r0/download/{destination}/{media_id}`
    and delegates to `self.client.get_file`.

    Args:
        destination: The remote server to send the request to; also used as
            the server name component of the request path.
        media_id: The ID of the media, used as the last path component.
        output_stream: Passed through to `get_file` as the stream that
            receives the file.
        max_size: Passed through to `get_file` as the size limit.
        max_timeout_ms: Sent to the remote server, stringified, as the
            `timeout_ms` query argument.

    Returns:
        Whatever `get_file` returns.
    """
    path = f"/_matrix/media/r0/download/{destination}/{media_id}"

    return await self.client.get_file(
        destination,
        path,
        output_stream=output_stream,
        max_size=max_size,
        args={
            # tell the remote server to 404 if it doesn't
            # recognise the server_name, to make sure we don't
            # end up with a routing loop.
            "allow_remote": "false",
            "timeout_ms": str(max_timeout_ms),
        },
    )

async def download_media_v3(
    self,
    destination: str,
    media_id: str,
    output_stream: BinaryIO,
    max_size: int,
    max_timeout_ms: int,
) -> Tuple[int, Dict[bytes, List[bytes]]]:
    """Request media from a remote server using the v3 download endpoint.

    Builds the path `/_matrix/media/v3/download/{destination}/{media_id}`
    and delegates to `self.client.get_file`.

    Args:
        destination: The remote server to send the request to; also used as
            the server name component of the request path.
        media_id: The ID of the media, used as the last path component.
        output_stream: Passed through to `get_file` as the stream that
            receives the file.
        max_size: Passed through to `get_file` as the size limit.
        max_timeout_ms: Sent to the remote server, stringified, as the
            `timeout_ms` query argument.

    Returns:
        Whatever `get_file` returns.
    """
    path = f"/_matrix/media/v3/download/{destination}/{media_id}"

    return await self.client.get_file(
        destination,
        path,
        output_stream=output_stream,
        max_size=max_size,
        args={
            # tell the remote server to 404 if it doesn't
            # recognise the server_name, to make sure we don't
            # end up with a routing loop.
            "allow_remote": "false",
            "timeout_ms": str(max_timeout_ms),
        },
    )


def _create_path(federation_prefix: str, path: str, *args: str) -> str:
"""
Expand Down
17 changes: 4 additions & 13 deletions synapse/media/media_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ class MediaRepository:
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.client = hs.get_federation_http_client()
self.client = hs.get_federation_client()
self.clock = hs.get_clock()
self.server_name = hs.hostname
self.store = hs.get_datastores().main
Expand Down Expand Up @@ -644,22 +644,13 @@ async def _download_remote_file(
file_info = FileInfo(server_name=server_name, file_id=file_id)

with self.media_storage.store_into_file(file_info) as (f, fname, finish):
request_path = "/".join(
("/_matrix/media/r0/download", server_name, media_id)
)
try:
length, headers = await self.client.get_file(
length, headers = await self.client.download_media(
server_name,
request_path,
media_id,
output_stream=f,
max_size=self.max_upload_size,
args={
# tell the remote server to 404 if it doesn't
# recognise the server_name, to make sure we don't
# end up with a routing loop.
"allow_remote": "false",
"timeout_ms": str(max_timeout_ms),
},
max_timeout_ms=max_timeout_ms,
)
except RequestSendFailed as e:
logger.warning(
Expand Down
58 changes: 55 additions & 3 deletions tests/media/test_media_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,11 @@

from twisted.internet import defer
from twisted.internet.defer import Deferred
from twisted.python.failure import Failure
from twisted.test.proto_helpers import MemoryReactor
from twisted.web.resource import Resource

from synapse.api.errors import Codes
from synapse.api.errors import Codes, HttpResponseException
from synapse.events import EventBase
from synapse.http.types import QueryParams
from synapse.logging.context import make_deferred_yieldable
Expand Down Expand Up @@ -257,10 +258,15 @@ def write_to(
output_stream.write(data)
return response

def write_err(f: Failure) -> Failure:
f.trap(HttpResponseException)
output_stream.write(f.value.response)
return f
Comment on lines +262 to +265
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This one was new to me: https://docs.twisted.org/en/stable/api/twisted.python.failure.Failure.html#trap

TL;DR the trap call is a no-op if f contains an HttpResponseException; otherwise the trap raises immediately so that the next errback can handle this Failure.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd missed that this was in test code though. Why do we need to add this as an errback all of a sudden?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to add it because we now sometimes call errback on the list of Deferreds, so that we can resolve a request with an error instead of a response.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ahhh, I think I see: we never called errback on the mock until now!


d: Deferred[Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]] = Deferred()
self.fetches.append((d, destination, path, args))
# Note that this callback changes the value held by d.
d_after_callback = d.addCallback(write_to)
d_after_callback = d.addCallbacks(write_to, write_err)
return make_deferred_yieldable(d_after_callback)

# Mock out the homeserver's MatrixFederationHttpClient
Expand Down Expand Up @@ -316,7 +322,7 @@ def _req(
self.assertEqual(len(self.fetches), 1)
self.assertEqual(self.fetches[0][1], "example.com")
self.assertEqual(
self.fetches[0][2], "/_matrix/media/r0/download/" + self.media_id
self.fetches[0][2], "/_matrix/media/v3/download/" + self.media_id
)
self.assertEqual(
self.fetches[0][3], {"allow_remote": "false", "timeout_ms": "20000"}
Expand Down Expand Up @@ -671,6 +677,52 @@ def test_cross_origin_resource_policy_header(self) -> None:
[b"cross-origin"],
)

def test_unknown_v3_endpoint(self) -> None:
    """
    If the v3 endpoint fails, try the r0 one.

    The first outbound fetch (to the v3 path) is failed with a 404
    M_UNRECOGNIZED error; a second fetch to the r0 path is then expected,
    and completing it successfully should yield a 200 for the original
    request.
    """
    channel = self.make_request(
        "GET",
        f"/_matrix/media/v3/download/{self.media_id}",
        shorthand=False,
        await_result=False,
    )
    self.pump()

    # We've made one fetch, to example.com, using the v3 media URL.
    self.assertEqual(len(self.fetches), 1)
    self.assertEqual(self.fetches[0][1], "example.com")
    self.assertEqual(
        self.fetches[0][2], f"/_matrix/media/v3/download/{self.media_id}"
    )

    # The result which says the endpoint is unknown.
    unknown_endpoint = b'{"errcode":"M_UNRECOGNIZED","error":"Unknown request"}'
    self.fetches[0][0].errback(
        HttpResponseException(404, "NOT FOUND", unknown_endpoint)
    )

    self.pump()

    # There should now be another request to the r0 URL.
    self.assertEqual(len(self.fetches), 2)
    self.assertEqual(self.fetches[1][1], "example.com")
    self.assertEqual(
        self.fetches[1][2], f"/_matrix/media/r0/download/{self.media_id}"
    )

    headers = {
        b"Content-Length": [b"%d" % (len(self.test_image.data))],
    }

    # Resolve the r0 fetch successfully with the test image.
    self.fetches[1][0].callback(
        (self.test_image.data, (len(self.test_image.data), headers))
    )

    self.pump()
    self.assertEqual(channel.code, 200)


class TestSpamCheckerLegacy:
"""A spam checker module that rejects all media that includes the bytes
Expand Down