Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Allow local media to be marked as safe from being quarantined #7718

Merged
merged 16 commits into from
Jun 22, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/7718.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Media can now be marked as safe from quarantined.
1 change: 1 addition & 0 deletions scripts/synapse_port_db
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ BOOLEAN_COLUMNS = {
"account_validity": ["email_sent"],
"redactions": ["have_censored"],
"room_stats_state": ["is_federatable"],
"local_media_repository": ["safe_from_quarantine"],
}


Expand Down
9 changes: 9 additions & 0 deletions synapse/storage/data_stores/main/media_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,15 @@ def store_local_media(
desc="store_local_media",
)

def mark_local_media_as_safe(self, media_id: str):
"""Mark a local media as safe from quarantining."""
return self.db.simple_update_one(
table="local_media_repository",
keyvalues={"media_id": media_id},
updatevalues={"safe_from_quarantine": True},
desc="mark_local_media_as_safe",
)

def get_url_cache(self, url, ts):
"""Get the media_id and ts for a cached URL as of the given timestamp
Returns:
Expand Down
42 changes: 7 additions & 35 deletions synapse/storage/data_stores/main/room.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,36 +626,10 @@ def quarantine_media_ids_in_room(self, room_id, quarantined_by):

def _quarantine_media_in_room_txn(txn):
local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
total_media_quarantined = 0

# Now update all the tables to set the quarantined_by flag

txn.executemany(
"""
UPDATE local_media_repository
SET quarantined_by = ?
WHERE media_id = ?
""",
((quarantined_by, media_id) for media_id in local_mxcs),
)

txn.executemany(
"""
UPDATE remote_media_cache
SET quarantined_by = ?
WHERE media_origin = ? AND media_id = ?
""",
(
(quarantined_by, origin, media_id)
for origin, media_id in remote_mxcs
),
return self._quarantine_media_txn(
txn, local_mxcs, remote_mxcs, quarantined_by
)

total_media_quarantined += len(local_mxcs)
total_media_quarantined += len(remote_mxcs)

return total_media_quarantined

return self.db.runInteraction(
"quarantine_media_in_room", _quarantine_media_in_room_txn
)
Expand Down Expand Up @@ -805,17 +779,17 @@ def _quarantine_media_txn(
Returns:
The total number of media items quarantined
"""
total_media_quarantined = 0

# Update all the tables to set the quarantined_by flag
txn.executemany(
"""
UPDATE local_media_repository
SET quarantined_by = ?
WHERE media_id = ?
WHERE media_id = ? AND safe_from_quarantine = ?
""",
((quarantined_by, media_id) for media_id in local_mxcs),
((quarantined_by, media_id, False) for media_id in local_mxcs),
)
# Note that a rowcount of -1 can be used to indicate no rows were affected.
total_media_quarantined = txn.rowcount if txn.rowcount > 0 else 0

txn.executemany(
"""
Expand All @@ -825,9 +799,7 @@ def _quarantine_media_txn(
""",
((quarantined_by, origin, media_id) for origin, media_id in remote_mxcs),
)

total_media_quarantined += len(local_mxcs)
total_media_quarantined += len(remote_mxcs)
total_media_quarantined += txn.rowcount if txn.rowcount > 0 else 0

return total_media_quarantined

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
/* Copyright 2020 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

-- The local_media_repository should have files which do not get quarantined,
-- e.g. files from sticker packs.
ALTER TABLE local_media_repository ADD COLUMN safe_from_quarantine BOOLEAN NOT NULL DEFAULT FALSE;
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
/* Copyright 2020 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

-- The local_media_repository should have files which do not get quarantined,
-- e.g. files from sticker packs.
ALTER TABLE local_media_repository ADD COLUMN safe_from_quarantine BOOLEAN NOT NULL DEFAULT 0;
137 changes: 65 additions & 72 deletions tests/rest/admin/test_admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,24 @@ def write_to(r):

return hs

def _ensure_quarantined(self, admin_user_tok, server_and_media_id):
"""Ensure a piece of media is quarantined when trying to access it."""
request, channel = self.make_request(
"GET", server_and_media_id, shorthand=False, access_token=admin_user_tok,
)
request.render(self.download_resource)
self.pump(1.0)

# Should be quarantined
self.assertEqual(
404,
int(channel.code),
msg=(
"Expected to receive a 404 on accessing quarantined media: %s"
% server_and_media_id
),
)

def test_quarantine_media_requires_admin(self):
self.register_user("nonadmin", "pass", admin=False)
non_admin_user_tok = self.login("nonadmin", "pass")
Expand Down Expand Up @@ -292,24 +310,7 @@ def test_quarantine_media_by_id(self):
self.assertEqual(200, int(channel.code), msg=channel.result["body"])

# Attempt to access the media
request, channel = self.make_request(
"GET",
server_name_and_media_id,
shorthand=False,
access_token=admin_user_tok,
)
request.render(self.download_resource)
self.pump(1.0)

# Should be quarantined
self.assertEqual(
404,
int(channel.code),
msg=(
"Expected to receive a 404 on accessing quarantined media: %s"
% server_name_and_media_id
),
)
self._ensure_quarantined(admin_user_tok, server_name_and_media_id)

def test_quarantine_all_media_in_room(self, override_url_template=None):
self.register_user("room_admin", "pass", admin=True)
Expand Down Expand Up @@ -371,45 +372,10 @@ def test_quarantine_all_media_in_room(self, override_url_template=None):
server_and_media_id_2 = mxc_2[6:]

# Test that we cannot download any of the media anymore
request, channel = self.make_request(
"GET",
server_and_media_id_1,
shorthand=False,
access_token=non_admin_user_tok,
)
request.render(self.download_resource)
self.pump(1.0)

# Should be quarantined
self.assertEqual(
404,
int(channel.code),
msg=(
"Expected to receive a 404 on accessing quarantined media: %s"
% server_and_media_id_1
),
)

request, channel = self.make_request(
"GET",
server_and_media_id_2,
shorthand=False,
access_token=non_admin_user_tok,
)
request.render(self.download_resource)
self.pump(1.0)

# Should be quarantined
self.assertEqual(
404,
int(channel.code),
msg=(
"Expected to receive a 404 on accessing quarantined media: %s"
% server_and_media_id_2
),
)
self._ensure_quarantined(admin_user_tok, server_and_media_id_1)
self._ensure_quarantined(admin_user_tok, server_and_media_id_2)

def test_quaraantine_all_media_in_room_deprecated_api_path(self):
def test_quarantine_all_media_in_room_deprecated_api_path(self):
# Perform the above test with the deprecated API path
self.test_quarantine_all_media_in_room("/_synapse/admin/v1/quarantine_media/%s")

Expand Down Expand Up @@ -449,25 +415,52 @@ def test_quarantine_all_media_by_user(self):
)

# Attempt to access each piece of media
self._ensure_quarantined(admin_user_tok, server_and_media_id_1)
self._ensure_quarantined(admin_user_tok, server_and_media_id_2)

def test_cannot_quarantine_safe_media(self):
self.register_user("user_admin", "pass", admin=True)
admin_user_tok = self.login("user_admin", "pass")

non_admin_user = self.register_user("user_nonadmin", "pass", admin=False)
non_admin_user_tok = self.login("user_nonadmin", "pass")

# Upload some media
response_1 = self.helper.upload_media(
self.upload_resource, self.image_data, tok=non_admin_user_tok
)
response_2 = self.helper.upload_media(
self.upload_resource, self.image_data, tok=non_admin_user_tok
)

# Extract media IDs
server_and_media_id_1 = response_1["content_uri"][6:]
server_and_media_id_2 = response_2["content_uri"][6:]

# Mark the second item as safe from quarantine.
_, media_id_2 = server_and_media_id_2.split("/")
self.get_success(self.store.mark_local_media_as_safe(media_id_2))

# Quarantine all media by this user
url = "/_synapse/admin/v1/user/%s/media/quarantine" % urllib.parse.quote(
non_admin_user
)
request, channel = self.make_request(
"GET",
server_and_media_id_1,
shorthand=False,
access_token=non_admin_user_tok,
"POST", url.encode("ascii"), access_token=admin_user_tok,
)
request.render(self.download_resource)
self.render(request)
self.pump(1.0)

# Should be quarantined
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(
404,
int(channel.code),
msg=(
"Expected to receive a 404 on accessing quarantined media: %s"
% server_and_media_id_1,
),
json.loads(channel.result["body"].decode("utf-8")),
{"num_quarantined": 1},
"Expected 1 quarantined item",
)

# Attempt to access each piece of media, the first should fail, the
# second should succeed.
self._ensure_quarantined(admin_user_tok, server_and_media_id_1)

# Attempt to access each piece of media
request, channel = self.make_request(
"GET",
Expand All @@ -478,12 +471,12 @@ def test_quarantine_all_media_by_user(self):
request.render(self.download_resource)
self.pump(1.0)

# Should be quarantined
# Shouldn't be quarantined
self.assertEqual(
404,
200,
int(channel.code),
msg=(
"Expected to receive a 404 on accessing quarantined media: %s"
"Expected to receive a 200 on accessing not-quarantined media: %s"
% server_and_media_id_2
),
)