Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Fix race for concurrent downloads of remote media. #8682

Merged
merged 7 commits into from
Oct 30, 2020
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/8682.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix exception during handling multiple concurrent requests for remote media when using multiple media repositories.
165 changes: 105 additions & 60 deletions synapse/rest/media/v1/media_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,15 +305,12 @@ async def _get_remote_media_impl(
# file_id is the ID we use to track the file locally. If we've already
# seen the file then reuse the existing ID, otherwise genereate a new
# one.
if media_info:
file_id = media_info["filesystem_id"]
else:
file_id = random_string(24)

file_info = FileInfo(server_name, file_id)

# If we have an entry in the DB, try and look for it
if media_info:
file_id = media_info["filesystem_id"]
file_info = FileInfo(server_name, file_id)

if media_info["quarantined_by"]:
logger.info("Media is quarantined")
raise NotFoundError()
Expand All @@ -324,14 +321,34 @@ async def _get_remote_media_impl(

# Failed to find the file anywhere, lets download it.

media_info = await self._download_remote_file(server_name, media_id, file_id)
try:
media_info = await self._download_remote_file(server_name, media_id,)
except SynapseError:
raise
except Exception as e:
# An exception may be because we downloaded media in another
# process, so let's check if we magically have the media.
media_info = await self.store.get_cached_remote_media(server_name, media_id)
if not media_info:
raise e

file_id = media_info["filesystem_id"]
file_info = FileInfo(server_name, file_id)

# We generate thumbnails event if another process downloaded the media
# as a) it's conceivable that the other download request dies before it
# generates thumbnails, but mainly b) we want toe be sure the thumbnails
erikjohnston marked this conversation as resolved.
Show resolved Hide resolved
# have finished being generated before responding to the client,
# otherwise they'll request thumbnails and get a 404 if they're not
# ready yet.
await self._generate_thumbnails(
server_name, media_id, file_id, media_info["media_type"]
)
clokep marked this conversation as resolved.
Show resolved Hide resolved

responder = await self.media_storage.fetch_media(file_info)
return responder, media_info

async def _download_remote_file(
self, server_name: str, media_id: str, file_id: str
) -> dict:
async def _download_remote_file(self, server_name: str, media_id: str,) -> dict:
"""Attempt to download the remote file from the given server name,
using the given file_id as the local id.

Expand All @@ -346,6 +363,8 @@ async def _download_remote_file(
The media info of the file.
"""

file_id = random_string(24)

file_info = FileInfo(server_name=server_name, file_id=file_id)

with self.media_storage.store_into_file(file_info) as (f, fname, finish):
Expand Down Expand Up @@ -401,22 +420,32 @@ async def _download_remote_file(

await finish()

media_type = headers[b"Content-Type"][0].decode("ascii")
upload_name = get_filename_from_headers(headers)
time_now_ms = self.clock.time_msec()
media_type = headers[b"Content-Type"][0].decode("ascii")
upload_name = get_filename_from_headers(headers)
time_now_ms = self.clock.time_msec()

# Multiple remote media download requests can race (when using
# multiple media repos), so this may throw a violation constraint
# exception. If it does we'll delete the newly downloaded file from
# disk (as we're in the ctx manager).
#
# However: we've already called `finish()` so we may have also
# written to the storage providers. This is preferable to the
# alternative where we call `finish()` *after* this, where we could
# end up having an entry in the DB but fail to write the files to
# the storage providers.
Comment on lines +432 to +436
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Meaning we might still end up with files that are useless?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, we should file an issue to fix that though (e.g adding a delete_file function to the storage providers and calling that).

await self.store.store_cached_remote_media(
origin=server_name,
media_id=media_id,
media_type=media_type,
time_now_ms=self.clock.time_msec(),
upload_name=upload_name,
media_length=length,
filesystem_id=file_id,
)

logger.info("Stored remote media in file %r", fname)

await self.store.store_cached_remote_media(
origin=server_name,
media_id=media_id,
media_type=media_type,
time_now_ms=self.clock.time_msec(),
upload_name=upload_name,
media_length=length,
filesystem_id=file_id,
)

media_info = {
"media_type": media_type,
"media_length": length,
Expand All @@ -425,8 +454,6 @@ async def _download_remote_file(
"filesystem_id": file_id,
}

await self._generate_thumbnails(server_name, media_id, file_id, media_type)

return media_info

def _get_thumbnail_requirements(self, media_type):
Expand Down Expand Up @@ -692,42 +719,60 @@ async def _generate_thumbnails(
if not t_byte_source:
continue

try:
file_info = FileInfo(
server_name=server_name,
file_id=file_id,
thumbnail=True,
thumbnail_width=t_width,
thumbnail_height=t_height,
thumbnail_method=t_method,
thumbnail_type=t_type,
url_cache=url_cache,
)

output_path = await self.media_storage.store_file(
t_byte_source, file_info
)
finally:
t_byte_source.close()

t_len = os.path.getsize(output_path)
file_info = FileInfo(
server_name=server_name,
file_id=file_id,
thumbnail=True,
thumbnail_width=t_width,
thumbnail_height=t_height,
thumbnail_method=t_method,
thumbnail_type=t_type,
url_cache=url_cache,
)

# Write to database
if server_name:
await self.store.store_remote_media_thumbnail(
server_name,
media_id,
file_id,
t_width,
t_height,
t_type,
t_method,
t_len,
)
else:
await self.store.store_local_thumbnail(
media_id, t_width, t_height, t_type, t_method, t_len
)
with self.media_storage.store_into_file(file_info) as (f, fname, finish):
try:
await self.media_storage.write_to_file(t_byte_source, f)
await finish()
finally:
t_byte_source.close()

t_len = os.path.getsize(fname)

# Write to database
if server_name:
# Multiple remote media download requests can race (when
# using multiple media repos), so this may throw a violation
# constraint exception. If it does we'll delete the newly
# generated thumbnail from disk (as we're in the ctx
# manager).
#
# However: we've already called `finish()` so we may have
# also written to the storage providers. This is preferable
# to the alternative where we call `finish()` *after* this,
# where we could end up having an entry in the DB but fail
# to write the files to the storage providers.
try:
await self.store.store_remote_media_thumbnail(
server_name,
media_id,
file_id,
t_width,
t_height,
t_type,
t_method,
t_len,
)
except Exception as e:
thumbnail_exists = await self.store.get_remote_media_thumbnail(
server_name, media_id, t_width, t_height, t_type,
)
if not thumbnail_exists:
raise e
else:
await self.store.store_local_thumbnail(
media_id, t_width, t_height, t_type, t_method, t_len
)

return {"width": m_width, "height": m_height}

Expand Down
30 changes: 20 additions & 10 deletions synapse/rest/media/v1/media_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def __init__(
storage_providers: Sequence["StorageProviderWrapper"],
):
self.hs = hs
self.reactor = hs.get_reactor()
self.local_media_directory = local_media_directory
self.filepaths = filepaths
self.storage_providers = storage_providers
Expand All @@ -70,13 +71,16 @@ async def store_file(self, source: IO, file_info: FileInfo) -> str:

with self.store_into_file(file_info) as (f, fname, finish_cb):
# Write to the main repository
await defer_to_thread(
self.hs.get_reactor(), _write_file_synchronously, source, f
)
await self.write_to_file(source, f)
await finish_cb()

return fname

async def write_to_file(self, source: IO, output: IO):
"""Asynchronously write the `source` to `output`.
"""
await defer_to_thread(self.reactor, _write_file_synchronously, source, output)

@contextlib.contextmanager
def store_into_file(self, file_info: FileInfo):
"""Context manager used to get a file like object to write into, as
Expand Down Expand Up @@ -112,14 +116,20 @@ def store_into_file(self, file_info: FileInfo):

finished_called = [False]

async def finish():
for provider in self.storage_providers:
await provider.store_file(path, file_info)

finished_called[0] = True

try:
with open(fname, "wb") as f:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Huh, this being a context manager here is a bit weird when we don't really write inside of it...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, we do in that we write during the yield, as the function is a context manager?


async def finish():
# Ensure that all writes have been flushed and close the
# file.
f.flush()
f.close()

for provider in self.storage_providers:
await provider.store_file(path, file_info)

finished_called[0] = True

yield f, fname, finish
except Exception:
try:
Expand Down Expand Up @@ -210,7 +220,7 @@ async def ensure_media_is_in_local_cache(self, file_info: FileInfo) -> str:
if res:
with res:
consumer = BackgroundFileConsumer(
open(local_path, "wb"), self.hs.get_reactor()
open(local_path, "wb"), self.reactor
)
await res.write_to_consumer(consumer)
await consumer.wait()
Expand Down
27 changes: 27 additions & 0 deletions synapse/storage/databases/main/media_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,33 @@ async def get_remote_media_thumbnails(
desc="get_remote_media_thumbnails",
)

async def get_remote_media_thumbnail(
self, origin: str, media_id: str, t_width: int, t_height: int, t_type: str,
) -> Optional[Dict[str, Any]]:
"""Fetch the thumbnail info of given width, height and type.
"""

return await self.db_pool.simple_select_one(
table="remote_media_cache_thumbnails",
keyvalues={
"media_origin": origin,
"media_id": media_id,
"thumbnail_width": t_width,
"thumbnail_height": t_height,
"thumbnail_type": t_type,
},
retcols=(
"thumbnail_width",
"thumbnail_height",
"thumbnail_method",
"thumbnail_type",
"thumbnail_length",
"filesystem_id",
),
allow_none=True,
desc="get_remote_media_thumbnail",
)

async def store_remote_media_thumbnail(
self,
origin,
Expand Down
Loading