Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Prevent multiple device list updates from breaking a batch send (#5156)
Browse files Browse the repository at this point in the history
fixes #5153
  • Loading branch information
anoadragon453 authored and richvdh committed Jun 6, 2019
1 parent a118650 commit 2d1d7b7
Show file tree
Hide file tree
Showing 4 changed files with 196 additions and 31 deletions.
1 change: 1 addition & 0 deletions changelog.d/5156.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Prevent federation device list updates breaking when processing multiple updates at once.
5 changes: 3 additions & 2 deletions synapse/federation/sender/per_destination_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,9 +349,10 @@ def _pop_pending_edus(self, limit):
@defer.inlineCallbacks
def _get_new_device_messages(self, limit):
last_device_list = self._last_device_list_stream_id
# Will return at most 20 entries

# Retrieve list of new device updates to send to the destination
now_stream_id, results = yield self._store.get_devices_by_remote(
self._destination, last_device_list
self._destination, last_device_list, limit=limit,
)
edus = [
Edu(
Expand Down
152 changes: 123 additions & 29 deletions synapse/storage/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.
import logging

from six import iteritems, itervalues
from six import iteritems

from canonicaljson import json

Expand Down Expand Up @@ -72,67 +72,146 @@ def get_devices_by_user(self, user_id):

defer.returnValue({d["device_id"]: d for d in devices})

def get_devices_by_remote(self, destination, from_stream_id):
@defer.inlineCallbacks
def get_devices_by_remote(self, destination, from_stream_id, limit):
"""Get stream of updates to send to remote servers
Returns:
(int, list[dict]): current stream id and list of updates
Deferred[tuple[int, list[dict]]]:
current stream id (ie, the stream id of the last update included in the
response), and the list of updates
"""
now_stream_id = self._device_list_id_gen.get_current_token()

has_changed = self._device_list_federation_stream_cache.has_entity_changed(
destination, int(from_stream_id)
)
if not has_changed:
return (now_stream_id, [])

return self.runInteraction(
defer.returnValue((now_stream_id, []))

# We retrieve n+1 devices from the list of outbound pokes where n is
# our outbound device update limit. We then check if the very last
# device has the same stream_id as the second-to-last device. If so,
# then we ignore all devices with that stream_id and only send the
# devices with a lower stream_id.
#
# If when culling the list we end up with no devices afterwards, we
# consider the device update to be too large, and simply skip the
# stream_id; the rationale being that such a large device list update
# is likely an error.
updates = yield self.runInteraction(
"get_devices_by_remote",
self._get_devices_by_remote_txn,
destination,
from_stream_id,
now_stream_id,
limit + 1,
)

# Return an empty list if there are no updates
if not updates:
defer.returnValue((now_stream_id, []))

# if we have exceeded the limit, we need to exclude any results with the
# same stream_id as the last row.
if len(updates) > limit:
stream_id_cutoff = updates[-1][2]
now_stream_id = stream_id_cutoff - 1
else:
stream_id_cutoff = None

# Perform the equivalent of a GROUP BY
#
# Iterate through the updates list and copy non-duplicate
# (user_id, device_id) entries into a map, with the value being
# the max stream_id across each set of duplicate entries
#
# maps (user_id, device_id) -> stream_id
# as long as their stream_id does not match that of the last row
query_map = {}
for update in updates:
if stream_id_cutoff is not None and update[2] >= stream_id_cutoff:
# Stop processing updates
break

key = (update[0], update[1])
query_map[key] = max(query_map.get(key, 0), update[2])

# If we didn't find any updates with a stream_id lower than the cutoff, it
# means that there are more than limit updates all of which have the same
# steam_id.

# That should only happen if a client is spamming the server with new
# devices, in which case E2E isn't going to work well anyway. We'll just
# skip that stream_id and return an empty list, and continue with the next
# stream_id next time.
if not query_map:
defer.returnValue((stream_id_cutoff, []))

results = yield self._get_device_update_edus_by_remote(
destination,
from_stream_id,
query_map,
)

defer.returnValue((now_stream_id, results))

def _get_devices_by_remote_txn(
self, txn, destination, from_stream_id, now_stream_id
self, txn, destination, from_stream_id, now_stream_id, limit
):
"""Return device update information for a given remote destination
Args:
txn (LoggingTransaction): The transaction to execute
destination (str): The host the device updates are intended for
from_stream_id (int): The minimum stream_id to filter updates by, exclusive
now_stream_id (int): The maximum stream_id to filter updates by, inclusive
limit (int): Maximum number of device updates to return
Returns:
List: List of device updates
"""
sql = """
SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes
SELECT user_id, device_id, stream_id FROM device_lists_outbound_pokes
WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
GROUP BY user_id, device_id
LIMIT 20
ORDER BY stream_id
LIMIT ?
"""
txn.execute(sql, (destination, from_stream_id, now_stream_id, False))
txn.execute(sql, (destination, from_stream_id, now_stream_id, False, limit))

# maps (user_id, device_id) -> stream_id
query_map = {(r[0], r[1]): r[2] for r in txn}
if not query_map:
return (now_stream_id, [])
return list(txn)

if len(query_map) >= 20:
now_stream_id = max(stream_id for stream_id in itervalues(query_map))
@defer.inlineCallbacks
def _get_device_update_edus_by_remote(
self, destination, from_stream_id, query_map,
):
"""Returns a list of device update EDUs as well as E2EE keys
devices = self._get_e2e_device_keys_txn(
txn,
Args:
destination (str): The host the device updates are intended for
from_stream_id (int): The minimum stream_id to filter updates by, exclusive
query_map (Dict[(str, str): int]): Dictionary mapping
user_id/device_id to update stream_id
Returns:
List[Dict]: List of objects representing an device update EDU
"""
devices = yield self.runInteraction(
"_get_e2e_device_keys_txn",
self._get_e2e_device_keys_txn,
query_map.keys(),
include_all_devices=True,
include_deleted_devices=True,
)

prev_sent_id_sql = """
SELECT coalesce(max(stream_id), 0) as stream_id
FROM device_lists_outbound_last_success
WHERE destination = ? AND user_id = ? AND stream_id <= ?
"""

results = []
for user_id, user_devices in iteritems(devices):
# The prev_id for the first row is always the last row before
# `from_stream_id`
txn.execute(prev_sent_id_sql, (destination, user_id, from_stream_id))
rows = txn.fetchall()
prev_id = rows[0][0]
prev_id = yield self._get_last_device_update_for_remote_user(
destination, user_id, from_stream_id,
)
for device_id, device in iteritems(user_devices):
stream_id = query_map[(user_id, device_id)]
result = {
Expand All @@ -156,7 +235,22 @@ def _get_devices_by_remote_txn(

results.append(result)

return (now_stream_id, results)
defer.returnValue(results)

def _get_last_device_update_for_remote_user(
self, destination, user_id, from_stream_id,
):
def f(txn):
prev_sent_id_sql = """
SELECT coalesce(max(stream_id), 0) as stream_id
FROM device_lists_outbound_last_success
WHERE destination = ? AND user_id = ? AND stream_id <= ?
"""
txn.execute(prev_sent_id_sql, (destination, user_id, from_stream_id))
rows = txn.fetchall()
return rows[0][0]

return self.runInteraction("get_last_device_update_for_remote_user", f)

def mark_as_sent_devices_by_remote(self, destination, stream_id):
"""Mark that updates have successfully been sent to the destination.
Expand Down
69 changes: 69 additions & 0 deletions tests/storage/test_devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,75 @@ def test_get_devices_by_user(self):
res["device2"],
)

@defer.inlineCallbacks
def test_get_devices_by_remote(self):
device_ids = ["device_id1", "device_id2"]

# Add two device updates with a single stream_id
yield self.store.add_device_change_to_streams(
"user_id", device_ids, ["somehost"],
)

# Get all device updates ever meant for this remote
now_stream_id, device_updates = yield self.store.get_devices_by_remote(
"somehost", -1, limit=100,
)

# Check original device_ids are contained within these updates
self._check_devices_in_updates(device_ids, device_updates)

@defer.inlineCallbacks
def test_get_devices_by_remote_limited(self):
# Test breaking the update limit in 1, 101, and 1 device_id segments

# first add one device
device_ids1 = ["device_id0"]
yield self.store.add_device_change_to_streams(
"user_id", device_ids1, ["someotherhost"],
)

# then add 101
device_ids2 = ["device_id" + str(i + 1) for i in range(101)]
yield self.store.add_device_change_to_streams(
"user_id", device_ids2, ["someotherhost"],
)

# then one more
device_ids3 = ["newdevice"]
yield self.store.add_device_change_to_streams(
"user_id", device_ids3, ["someotherhost"],
)

#
# now read them back.
#

# first we should get a single update
now_stream_id, device_updates = yield self.store.get_devices_by_remote(
"someotherhost", -1, limit=100,
)
self._check_devices_in_updates(device_ids1, device_updates)

# Then we should get an empty list back as the 101 devices broke the limit
now_stream_id, device_updates = yield self.store.get_devices_by_remote(
"someotherhost", now_stream_id, limit=100,
)
self.assertEqual(len(device_updates), 0)

# The 101 devices should've been cleared, so we should now just get one device
# update
now_stream_id, device_updates = yield self.store.get_devices_by_remote(
"someotherhost", now_stream_id, limit=100,
)
self._check_devices_in_updates(device_ids3, device_updates)

def _check_devices_in_updates(self, expected_device_ids, device_updates):
"""Check that an specific device ids exist in a list of device update EDUs"""
self.assertEqual(len(device_updates), len(expected_device_ids))

received_device_ids = {update["device_id"] for update in device_updates}
self.assertEqual(received_device_ids, set(expected_device_ids))

@defer.inlineCallbacks
def test_update_device(self):
yield self.store.store_device("user_id", "device_id", "display_name 1")
Expand Down

0 comments on commit 2d1d7b7

Please sign in to comment.