Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support MSC3916 by adding _matrix/client/v1/media/download endpoint #17365

Merged
merged 13 commits into from
Jul 2, 2024
Merged
1 change: 1 addition & 0 deletions changelog.d/17365.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Support [MSC3916](https://github.com/matrix-org/matrix-spec-proposals/blob/rav/authentication-for-media/proposals/3916-authentication-for-media.md) by adding _matrix/client/v1/media/download endpoint.
3 changes: 3 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,6 @@ ignore_missing_imports = True
# https://github.com/twisted/treq/pull/366
[mypy-treq.*]
ignore_missing_imports = True

[mypy-multipart.*]
ignore_missing_imports = True
18 changes: 16 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,8 @@ pydantic = ">=1.7.4, <3"
# needed.
setuptools_rust = ">=1.3"

# This is used for parsing multipart responses
python-multipart = ">=0.0.9"

# Optional Dependencies
# ---------------------
Expand Down
46 changes: 46 additions & 0 deletions synapse/federation/federation_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1871,6 +1871,52 @@ def filter_user_id(user_id: str) -> bool:

return filtered_statuses, filtered_failures

async def federation_download_media(
self,
destination: str,
media_id: str,
output_stream: BinaryIO,
max_size: int,
max_timeout_ms: int,
download_ratelimiter: Ratelimiter,
ip_address: str,
) -> Union[
Tuple[int, Dict[bytes, List[bytes]], bytes],
Tuple[int, Dict[bytes, List[bytes]]],
]:
try:
return await self.transport_layer.federation_download_media(
destination,
media_id,
output_stream=output_stream,
max_size=max_size,
max_timeout_ms=max_timeout_ms,
download_ratelimiter=download_ratelimiter,
ip_address=ip_address,
)
except HttpResponseException as e:
# If an error is received that is due to an unrecognised endpoint,
# fallback to the _matrix/media/v3/download endpoint. Otherwise, consider it a legitimate error
# and raise.
if not is_unknown_endpoint(e):
raise

logger.debug(
"Couldn't download media %s/%s over _matrix/federation/v1/media/download, falling back to _matrix/media/v3/download path",
destination,
media_id,
)

return await self.transport_layer.download_media_v3(
destination,
media_id,
output_stream=output_stream,
max_size=max_size,
max_timeout_ms=max_timeout_ms,
download_ratelimiter=download_ratelimiter,
ip_address=ip_address,
)

async def download_media(
self,
destination: str,
Expand Down
25 changes: 23 additions & 2 deletions synapse/federation/transport/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -824,7 +824,6 @@ async def download_media_r0(
ip_address: str,
) -> Tuple[int, Dict[bytes, List[bytes]]]:
path = f"/_matrix/media/r0/download/{destination}/{media_id}"

return await self.client.get_file(
destination,
path,
Expand Down Expand Up @@ -852,7 +851,6 @@ async def download_media_v3(
ip_address: str,
) -> Tuple[int, Dict[bytes, List[bytes]]]:
path = f"/_matrix/media/v3/download/{destination}/{media_id}"

return await self.client.get_file(
destination,
path,
Expand All @@ -873,6 +871,29 @@ async def download_media_v3(
ip_address=ip_address,
)

async def federation_download_media(
self,
destination: str,
media_id: str,
output_stream: BinaryIO,
max_size: int,
max_timeout_ms: int,
download_ratelimiter: Ratelimiter,
ip_address: str,
) -> Tuple[int, Dict[bytes, List[bytes]], bytes]:
path = f"/_matrix/federation/v1/media/download/{media_id}"
anoadragon453 marked this conversation as resolved.
Show resolved Hide resolved
return await self.client.federation_get_file(
destination,
path,
output_stream=output_stream,
max_size=max_size,
args={
"timeout_ms": str(max_timeout_ms),
},
download_ratelimiter=download_ratelimiter,
ip_address=ip_address,
)


def _create_path(federation_prefix: str, path: str, *args: str) -> str:
"""
Expand Down
9 changes: 3 additions & 6 deletions synapse/federation/transport/server/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@
from synapse.federation.transport.server.federation import (
FEDERATION_SERVLET_CLASSES,
FederationAccountStatusServlet,
FederationMediaDownloadServlet,
FederationUnstableClientKeysClaimServlet,
FederationUnstableMediaDownloadServlet,
)
from synapse.http.server import HttpServer, JsonResource
from synapse.http.servlet import (
Expand Down Expand Up @@ -316,11 +316,8 @@ def register_servlets(
):
continue

if servletclass == FederationUnstableMediaDownloadServlet:
if (
not hs.config.server.enable_media_repo
or not hs.config.experimental.msc3916_authenticated_media_enabled
):
if servletclass == FederationMediaDownloadServlet:
if not hs.config.server.enable_media_repo:
continue

servletclass(
Expand Down
4 changes: 2 additions & 2 deletions synapse/federation/transport/server/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ async def new_func(
return None
if (
func.__self__.__class__.__name__ # type: ignore
== "FederationUnstableMediaDownloadServlet"
== "FederationMediaDownloadServlet"
):
response = await func(
origin, content, request, *args, **kwargs
Expand All @@ -374,7 +374,7 @@ async def new_func(
else:
if (
func.__self__.__class__.__name__ # type: ignore
== "FederationUnstableMediaDownloadServlet"
== "FederationMediaDownloadServlet"
):
response = await func(
origin, content, request, *args, **kwargs
Expand Down
5 changes: 2 additions & 3 deletions synapse/federation/transport/server/federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -790,15 +790,14 @@ async def on_POST(
return 200, {"account_statuses": statuses, "failures": failures}


class FederationUnstableMediaDownloadServlet(BaseFederationServerServlet):
class FederationMediaDownloadServlet(BaseFederationServerServlet):
"""
Implementation of new federation media `/download` endpoint outlined in MSC3916. Returns
a multipart/mixed response consisting of a JSON object and the requested media
item. This endpoint only returns local media.
"""

PATH = "/media/download/(?P<media_id>[^/]*)"
PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc3916"
RATELIMIT = True

def __init__(
Expand Down Expand Up @@ -858,5 +857,5 @@ async def on_GET(
FederationV1SendKnockServlet,
FederationMakeKnockServlet,
FederationAccountStatusServlet,
FederationUnstableMediaDownloadServlet,
FederationMediaDownloadServlet,
)
152 changes: 152 additions & 0 deletions synapse/http/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
Union,
)

import attr
import multipart
import treq
from canonicaljson import encode_canonical_json
from netaddr import AddrFormatError, IPAddress, IPSet
Expand Down Expand Up @@ -1006,6 +1008,130 @@ def connectionLost(self, reason: Failure = connectionDone) -> None:
self._maybe_fail()


@attr.s(auto_attribs=True, slots=True)
class MultipartResponse:
"""
A small class to hold parsed values of a multipart response.
"""

json: bytes = b"{}"
length: Optional[int] = None
content_type: Optional[bytes] = None
disposition: Optional[bytes] = None
url: Optional[bytes] = None


class _MultipartParserProtocol(protocol.Protocol):
"""
Protocol to read and parse a MSC3916 multipart/mixed response
"""

transport: Optional[ITCPTransport] = None

def __init__(
self,
stream: ByteWriteable,
deferred: defer.Deferred,
boundary: str,
max_length: Optional[int],
) -> None:
self.stream = stream
self.deferred = deferred
self.boundary = boundary
self.max_length = max_length
self.parser = None
self.multipart_response = MultipartResponse()
self.has_redirect = False
self.in_json = False
self.json_done = False
self.file_length = 0
self.total_length = 0
self.in_disposition = False
self.in_content_type = False

def dataReceived(self, incoming_data: bytes) -> None:
if self.deferred.called:
return

# we don't have a parser yet, instantiate it
if not self.parser:

def on_header_field(data: bytes, start: int, end: int) -> None:
if data[start:end] == b"Location":
H-Shay marked this conversation as resolved.
Show resolved Hide resolved
self.has_redirect = True
if data[start:end] == b"Content-Disposition":
self.in_disposition = True
if data[start:end] == b"Content-Type":
self.in_content_type = True

def on_header_value(data: bytes, start: int, end: int) -> None:
# the first header should be content-type for application/json
if not self.in_json and not self.json_done:
assert data[start:end] == b"application/json"
self.in_json = True
elif self.has_redirect:
self.multipart_response.url = data[start:end]
elif self.in_content_type:
self.multipart_response.content_type = data[start:end]
self.in_content_type = False
elif self.in_disposition:
self.multipart_response.disposition = data[start:end]
self.in_disposition = False

def on_part_data(data: bytes, start: int, end: int) -> None:
# we've seen json header but haven't written the json data
if self.in_json and not self.json_done:
self.multipart_response.json = data[start:end]
self.json_done = True
# we have a redirect header rather than a file, and have already captured it
elif self.has_redirect:
return
# otherwise we are in the file part
else:
logger.info("Writing multipart file data to stream")
try:
self.stream.write(data[start:end])
except Exception as e:
logger.warning(
f"Exception encountered writing file data to stream: {e}"
)
self.deferred.errback()
self.file_length += end - start

callbacks = {
"on_header_field": on_header_field,
"on_header_value": on_header_value,
"on_part_data": on_part_data,
}
self.parser = multipart.MultipartParser(self.boundary, callbacks)

self.total_length += len(incoming_data)
if self.max_length is not None and self.total_length >= self.max_length:
self.deferred.errback(BodyExceededMaxSize())
# Close the connection (forcefully) since all the data will get
# discarded anyway.
assert self.transport is not None
self.transport.abortConnection()

try:
self.parser.write(incoming_data) # type: ignore[attr-defined]
except Exception as e:
logger.warning(f"Exception writing to multipart parser: {e}")
self.deferred.errback()
return

def connectionLost(self, reason: Failure = connectionDone) -> None:
# If the maximum size was already exceeded, there's nothing to do.
if self.deferred.called:
return

if reason.check(ResponseDone):
self.multipart_response.length = self.file_length
self.deferred.callback(self.multipart_response)
else:
self.deferred.errback(reason)


class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
"""A protocol which reads body to a stream, erroring if the body exceeds a maximum size."""

Expand Down Expand Up @@ -1091,6 +1217,32 @@ def read_body_with_max_size(
return d


def read_multipart_response(
response: IResponse, stream: ByteWriteable, boundary: str, max_length: Optional[int]
) -> "defer.Deferred[MultipartResponse]":
"""
Reads a MSC3916 multipart/mixed response and parses it, reading the file part (if it contains one) into
the stream passed in and returning a deferred resolving to a MultipartResponse

Args:
response: The HTTP response to read from.
stream: The file-object to write to.
boundary: the multipart/mixed boundary string
max_length: maximum allowable length of the response
"""
d: defer.Deferred[MultipartResponse] = defer.Deferred()

# If the Content-Length header gives a size larger than the maximum allowed
# size, do not bother downloading the body.
if max_length is not None and response.length != UNKNOWN_LENGTH:
if response.length > max_length:
response.deliverBody(_DiscardBodyWithMaxSizeProtocol(d))
return d

response.deliverBody(_MultipartParserProtocol(stream, d, boundary, max_length))
return d


def encode_query_args(args: Optional[QueryParams]) -> bytes:
"""
Encodes a map of query arguments to bytes which can be appended to a URL.
Expand Down
Loading
Loading