Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Fix race for concurrent downloads of remote media. (#8682)
Browse files Browse the repository at this point in the history
Fixes #6755
  • Loading branch information
erikjohnston authored Oct 30, 2020
1 parent 4504151 commit 46f4be9
Show file tree
Hide file tree
Showing 6 changed files with 431 additions and 71 deletions.
1 change: 1 addition & 0 deletions changelog.d/8682.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix an exception raised while handling multiple concurrent requests for remote media when using multiple media repositories.
165 changes: 105 additions & 60 deletions synapse/rest/media/v1/media_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,15 +305,12 @@ async def _get_remote_media_impl(
# file_id is the ID we use to track the file locally. If we've already
# seen the file then reuse the existing ID, otherwise generate a new
# one.
if media_info:
file_id = media_info["filesystem_id"]
else:
file_id = random_string(24)

file_info = FileInfo(server_name, file_id)

# If we have an entry in the DB, try and look for it
if media_info:
file_id = media_info["filesystem_id"]
file_info = FileInfo(server_name, file_id)

if media_info["quarantined_by"]:
logger.info("Media is quarantined")
raise NotFoundError()
Expand All @@ -324,14 +321,34 @@ async def _get_remote_media_impl(

# Failed to find the file anywhere, let's download it.

media_info = await self._download_remote_file(server_name, media_id, file_id)
try:
media_info = await self._download_remote_file(server_name, media_id,)
except SynapseError:
raise
except Exception as e:
# An exception may be because we downloaded media in another
# process, so let's check if we magically have the media.
media_info = await self.store.get_cached_remote_media(server_name, media_id)
if not media_info:
raise e

file_id = media_info["filesystem_id"]
file_info = FileInfo(server_name, file_id)

# We generate thumbnails even if another process downloaded the media
# as a) it's conceivable that the other download request dies before it
# generates thumbnails, but mainly b) we want to be sure the thumbnails
# have finished being generated before responding to the client,
# otherwise they'll request thumbnails and get a 404 if they're not
# ready yet.
await self._generate_thumbnails(
server_name, media_id, file_id, media_info["media_type"]
)

responder = await self.media_storage.fetch_media(file_info)
return responder, media_info

async def _download_remote_file(
self, server_name: str, media_id: str, file_id: str
) -> dict:
async def _download_remote_file(self, server_name: str, media_id: str,) -> dict:
"""Attempt to download the remote file from the given server name,
using a newly generated file_id as the local id.
Expand All @@ -346,6 +363,8 @@ async def _download_remote_file(
The media info of the file.
"""

file_id = random_string(24)

file_info = FileInfo(server_name=server_name, file_id=file_id)

with self.media_storage.store_into_file(file_info) as (f, fname, finish):
Expand Down Expand Up @@ -401,22 +420,32 @@ async def _download_remote_file(

await finish()

media_type = headers[b"Content-Type"][0].decode("ascii")
upload_name = get_filename_from_headers(headers)
time_now_ms = self.clock.time_msec()
media_type = headers[b"Content-Type"][0].decode("ascii")
upload_name = get_filename_from_headers(headers)
time_now_ms = self.clock.time_msec()

# Multiple remote media download requests can race (when using
# multiple media repos), so this may throw a constraint violation
# exception. If it does we'll delete the newly downloaded file from
# disk (as we're in the ctx manager).
#
# However: we've already called `finish()` so we may have also
# written to the storage providers. This is preferable to the
# alternative where we call `finish()` *after* this, where we could
# end up having an entry in the DB but fail to write the files to
# the storage providers.
await self.store.store_cached_remote_media(
origin=server_name,
media_id=media_id,
media_type=media_type,
time_now_ms=self.clock.time_msec(),
upload_name=upload_name,
media_length=length,
filesystem_id=file_id,
)

logger.info("Stored remote media in file %r", fname)

await self.store.store_cached_remote_media(
origin=server_name,
media_id=media_id,
media_type=media_type,
time_now_ms=self.clock.time_msec(),
upload_name=upload_name,
media_length=length,
filesystem_id=file_id,
)

media_info = {
"media_type": media_type,
"media_length": length,
Expand All @@ -425,8 +454,6 @@ async def _download_remote_file(
"filesystem_id": file_id,
}

await self._generate_thumbnails(server_name, media_id, file_id, media_type)

return media_info

def _get_thumbnail_requirements(self, media_type):
Expand Down Expand Up @@ -692,42 +719,60 @@ async def _generate_thumbnails(
if not t_byte_source:
continue

try:
file_info = FileInfo(
server_name=server_name,
file_id=file_id,
thumbnail=True,
thumbnail_width=t_width,
thumbnail_height=t_height,
thumbnail_method=t_method,
thumbnail_type=t_type,
url_cache=url_cache,
)

output_path = await self.media_storage.store_file(
t_byte_source, file_info
)
finally:
t_byte_source.close()

t_len = os.path.getsize(output_path)
file_info = FileInfo(
server_name=server_name,
file_id=file_id,
thumbnail=True,
thumbnail_width=t_width,
thumbnail_height=t_height,
thumbnail_method=t_method,
thumbnail_type=t_type,
url_cache=url_cache,
)

# Write to database
if server_name:
await self.store.store_remote_media_thumbnail(
server_name,
media_id,
file_id,
t_width,
t_height,
t_type,
t_method,
t_len,
)
else:
await self.store.store_local_thumbnail(
media_id, t_width, t_height, t_type, t_method, t_len
)
with self.media_storage.store_into_file(file_info) as (f, fname, finish):
try:
await self.media_storage.write_to_file(t_byte_source, f)
await finish()
finally:
t_byte_source.close()

t_len = os.path.getsize(fname)

# Write to database
if server_name:
# Multiple remote media download requests can race (when
# using multiple media repos), so this may throw a constraint
# violation exception. If it does we'll delete the newly
# generated thumbnail from disk (as we're in the ctx
# manager).
#
# However: we've already called `finish()` so we may have
# also written to the storage providers. This is preferable
# to the alternative where we call `finish()` *after* this,
# where we could end up having an entry in the DB but fail
# to write the files to the storage providers.
try:
await self.store.store_remote_media_thumbnail(
server_name,
media_id,
file_id,
t_width,
t_height,
t_type,
t_method,
t_len,
)
except Exception as e:
thumbnail_exists = await self.store.get_remote_media_thumbnail(
server_name, media_id, t_width, t_height, t_type,
)
if not thumbnail_exists:
raise e
else:
await self.store.store_local_thumbnail(
media_id, t_width, t_height, t_type, t_method, t_len
)

return {"width": m_width, "height": m_height}

Expand Down
30 changes: 20 additions & 10 deletions synapse/rest/media/v1/media_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def __init__(
storage_providers: Sequence["StorageProviderWrapper"],
):
self.hs = hs
self.reactor = hs.get_reactor()
self.local_media_directory = local_media_directory
self.filepaths = filepaths
self.storage_providers = storage_providers
Expand All @@ -70,13 +71,16 @@ async def store_file(self, source: IO, file_info: FileInfo) -> str:

with self.store_into_file(file_info) as (f, fname, finish_cb):
# Write to the main repository
await defer_to_thread(
self.hs.get_reactor(), _write_file_synchronously, source, f
)
await self.write_to_file(source, f)
await finish_cb()

return fname

async def write_to_file(self, source: IO, output: IO):
    """Asynchronously copy the contents of `source` into `output`.

    The copy is handed to `defer_to_thread` together with this object's
    reactor and the `_write_file_synchronously` helper; this coroutine
    completes once that call has been awaited.

    Args:
        source: File-like object to read from.
        output: File-like object to write into.
    """
    # NOTE(review): presumably the helper runs off the reactor thread so the
    # blocking file I/O does not stall it; confirm against defer_to_thread.
    reactor = self.reactor
    pending_copy = defer_to_thread(
        reactor, _write_file_synchronously, source, output
    )
    await pending_copy

@contextlib.contextmanager
def store_into_file(self, file_info: FileInfo):
"""Context manager used to get a file like object to write into, as
Expand Down Expand Up @@ -112,14 +116,20 @@ def store_into_file(self, file_info: FileInfo):

finished_called = [False]

async def finish():
for provider in self.storage_providers:
await provider.store_file(path, file_info)

finished_called[0] = True

try:
with open(fname, "wb") as f:

async def finish():
# Ensure that all writes have been flushed and close the
# file.
f.flush()
f.close()

for provider in self.storage_providers:
await provider.store_file(path, file_info)

finished_called[0] = True

yield f, fname, finish
except Exception:
try:
Expand Down Expand Up @@ -210,7 +220,7 @@ async def ensure_media_is_in_local_cache(self, file_info: FileInfo) -> str:
if res:
with res:
consumer = BackgroundFileConsumer(
open(local_path, "wb"), self.hs.get_reactor()
open(local_path, "wb"), self.reactor
)
await res.write_to_consumer(consumer)
await consumer.wait()
Expand Down
27 changes: 27 additions & 0 deletions synapse/storage/databases/main/media_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,33 @@ async def get_remote_media_thumbnails(
desc="get_remote_media_thumbnails",
)

async def get_remote_media_thumbnail(
    self,
    origin: str,
    media_id: str,
    t_width: int,
    t_height: int,
    t_type: str,
    t_method: Optional[str] = None,
) -> Optional[Dict[str, Any]]:
    """Fetch the info of one cached remote thumbnail.

    Looks up a single row in `remote_media_cache_thumbnails` matching the
    given origin, media ID, width, height and type.

    Args:
        origin: Value matched against the `media_origin` column.
        media_id: Value matched against the `media_id` column.
        t_width: Value matched against the `thumbnail_width` column.
        t_height: Value matched against the `thumbnail_height` column.
        t_type: Value matched against the `thumbnail_type` column.
        t_method: Optional value matched against the `thumbnail_method`
            column. When None (the default) the method is not part of the
            lookup, which is the historical behaviour.

    Returns:
        The result of `simple_select_one` for the requested columns. The
        lookup is made with `allow_none=True`, so presumably this is None
        when no row matches (confirm against the database pool).
    """
    keyvalues: Dict[str, Any] = {
        "media_origin": origin,
        "media_id": media_id,
        "thumbnail_width": t_width,
        "thumbnail_height": t_height,
        "thumbnail_type": t_type,
    }

    if t_method is not None:
        # Width, height and type alone do not distinguish thumbnails that
        # differ only by method, so let callers narrow a single-row lookup.
        keyvalues["thumbnail_method"] = t_method

    return await self.db_pool.simple_select_one(
        table="remote_media_cache_thumbnails",
        keyvalues=keyvalues,
        retcols=(
            "thumbnail_width",
            "thumbnail_height",
            "thumbnail_method",
            "thumbnail_type",
            "thumbnail_length",
            "filesystem_id",
        ),
        allow_none=True,
        desc="get_remote_media_thumbnail",
    )

async def store_remote_media_thumbnail(
self,
origin,
Expand Down
Loading

0 comments on commit 46f4be9

Please sign in to comment.