Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Add tests for outbound device pokes
Browse files Browse the repository at this point in the history
  • Loading branch information
richvdh committed Mar 27, 2020
1 parent 28d9d6e commit 665630f
Show file tree
Hide file tree
Showing 3 changed files with 302 additions and 3 deletions.
1 change: 1 addition & 0 deletions changelog.d/7157.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add tests for outbound device pokes.
303 changes: 300 additions & 3 deletions tests/federation/test_federation_sender.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,25 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional

from mock import Mock

from signedjson import key, sign
from signedjson.types import BaseKey, SigningKey

from twisted.internet import defer

from synapse.types import ReadReceipt
from synapse.rest import admin
from synapse.rest.client.v1 import login
from synapse.types import JsonDict, ReadReceipt

from tests.unittest import HomeserverTestCase, override_config


class FederationSenderTestCases(HomeserverTestCase):
class FederationSenderReceiptsTestCases(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
return super(FederationSenderTestCases, self).setup_test_homeserver(
return self.setup_test_homeserver(
state_handler=Mock(spec=["get_current_hosts_in_room"]),
federation_transport_client=Mock(spec=["send_transaction"]),
)
Expand Down Expand Up @@ -147,3 +153,294 @@ def test_send_receipts_with_backoff(self):
}
],
)


class FederationSenderDevicesTestCases(HomeserverTestCase):
servlets = [
admin.register_servlets,
login.register_servlets,
]

def make_homeserver(self, reactor, clock):
return self.setup_test_homeserver(
state_handler=Mock(spec=["get_current_hosts_in_room"]),
federation_transport_client=Mock(spec=["send_transaction"]),
)

def default_config(self):
c = super().default_config()
c["send_federation"] = True
return c

def prepare(self, reactor, clock, hs):
# stub out get_current_hosts_in_room
mock_state_handler = hs.get_state_handler()
mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"]

# stub out get_users_who_share_room_with_user so that it claims that
# `@user2:host2` is in the room
def get_users_who_share_room_with_user(user_id):
return defer.succeed({"@user2:host2"})

hs.get_datastore().get_users_who_share_room_with_user = (
get_users_who_share_room_with_user
)

# whenever send_transaction is called, record the edu data
self.edus = []
self.hs.get_federation_transport_client().send_transaction.side_effect = (
self.record_transaction
)

def record_transaction(self, txn, json_cb):
data = json_cb()
self.edus.extend(data["edus"])
return defer.succeed({})

def test_send_device_updates(self):
"""Basic case: each device update should result in an EDU"""
# create a device
u1 = self.register_user("user", "pass")
self.login(u1, "pass", device_id="D1")

# expect one edu
self.assertEqual(len(self.edus), 1)
stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D1", None)

# a second call should produce no new device EDUs
self.hs.get_federation_sender().send_device_messages("host2")
self.pump()
self.assertEqual(self.edus, [])

# a second device
self.login("user", "pass", device_id="D2")

self.assertEqual(len(self.edus), 1)
self.check_device_update_edu(self.edus.pop(0), u1, "D2", stream_id)

def test_upload_signatures(self):
"""Uploading signatures on some devices should produce updates for that user"""

e2e_handler = self.hs.get_e2e_keys_handler()

# register two devices
u1 = self.register_user("user", "pass")
self.login(u1, "pass", device_id="D1")
self.login(u1, "pass", device_id="D2")

# expect two edus
self.assertEqual(len(self.edus), 2)
stream_id = None
stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D1", stream_id)
stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D2", stream_id)

# upload signing keys for each device
device1_signing_key = self.generate_and_upload_device_signing_key(u1, "D1")
device2_signing_key = self.generate_and_upload_device_signing_key(u1, "D2")

# expect two more edus
self.assertEqual(len(self.edus), 2)
stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D1", stream_id)
stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D2", stream_id)

# upload master key and self-signing key
master_signing_key = generate_self_id_key()
master_key = {
"user_id": u1,
"usage": ["master"],
"keys": {key_id(master_signing_key): encode_pubkey(master_signing_key)},
}

# private key: HvQBbU+hc2Zr+JP1sE0XwBe1pfZZEYtJNPJLZJtS+F8
selfsigning_signing_key = generate_self_id_key()
selfsigning_key = {
"user_id": u1,
"usage": ["self_signing"],
"keys": {
key_id(selfsigning_signing_key): encode_pubkey(selfsigning_signing_key)
},
}
sign.sign_json(selfsigning_key, u1, master_signing_key)

cross_signing_keys = {
"master_key": master_key,
"self_signing_key": selfsigning_key,
}

self.get_success(
e2e_handler.upload_signing_keys_for_user(u1, cross_signing_keys)
)

# expect signing key update edu
self.assertEqual(len(self.edus), 1)
self.assertEqual(self.edus.pop(0)["edu_type"], "org.matrix.signing_key_update")

# sign the devices
d1_json = build_device_dict(u1, "D1", device1_signing_key)
sign.sign_json(d1_json, u1, selfsigning_signing_key)
d2_json = build_device_dict(u1, "D2", device2_signing_key)
sign.sign_json(d2_json, u1, selfsigning_signing_key)

ret = self.get_success(
e2e_handler.upload_signatures_for_device_keys(
u1, {u1: {"D1": d1_json, "D2": d2_json}},
)
)
self.assertEqual(ret["failures"], {})

# expect two edus, in one or two transactions. We don't know what order the
# devices will be updated.
self.assertEqual(len(self.edus), 2)
stream_id = None # FIXME: there is a discontinuity in the stream IDs: see #7142
for edu in self.edus:
self.assertEqual(edu["edu_type"], "m.device_list_update")
c = edu["content"]
if stream_id is not None:
self.assertEqual(c["prev_id"], [stream_id])
stream_id = c["stream_id"]
devices = {edu["content"]["device_id"] for edu in self.edus}
self.assertEqual({"D1", "D2"}, devices)

def test_delete_devices(self):
"""If devices are deleted, that should result in EDUs too"""

# create devices
u1 = self.register_user("user", "pass")
self.login("user", "pass", device_id="D1")
self.login("user", "pass", device_id="D2")
self.login("user", "pass", device_id="D3")

# expect three edus
self.assertEqual(len(self.edus), 3)
stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D1", None)
stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D2", stream_id)
stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D3", stream_id)

# delete them again
self.get_success(
self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"])
)

# expect three edus, in an unknown order
self.assertEqual(len(self.edus), 3)
for edu in self.edus:
self.assertEqual(edu["edu_type"], "m.device_list_update")
c = edu["content"]
self.assertGreaterEqual(
c.items(),
{"user_id": u1, "prev_id": [stream_id], "deleted": True}.items(),
)
stream_id = c["stream_id"]
devices = {edu["content"]["device_id"] for edu in self.edus}
self.assertEqual({"D1", "D2", "D3"}, devices)

def test_unreachable_server(self):
"""If the destination server is unreachable, all the updates should get sent on
recovery
"""
mock_send_txn = self.hs.get_federation_transport_client().send_transaction
mock_send_txn.side_effect = lambda t, cb: defer.fail("fail")

# create devices
u1 = self.register_user("user", "pass")
self.login("user", "pass", device_id="D1")
self.login("user", "pass", device_id="D2")
self.login("user", "pass", device_id="D3")

# delete them again
self.get_success(
self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"])
)

self.assertGreaterEqual(mock_send_txn.call_count, 4)

# recover the server
mock_send_txn.side_effect = self.record_transaction
self.hs.get_federation_sender().send_device_messages("host2")
self.pump()

# for each device, there should be a single update
self.assertEqual(len(self.edus), 3)
stream_id = None
for edu in self.edus:
self.assertEqual(edu["edu_type"], "m.device_list_update")
c = edu["content"]
self.assertEqual(c["prev_id"], [stream_id] if stream_id is not None else [])
stream_id = c["stream_id"]
devices = {edu["content"]["device_id"] for edu in self.edus}
self.assertEqual({"D1", "D2", "D3"}, devices)

def check_device_update_edu(
self,
edu: JsonDict,
user_id: str,
device_id: str,
prev_stream_id: Optional[int],
) -> int:
"""Check that the given EDU is an update for the given device
Returns the stream_id.
"""
self.assertEqual(edu["edu_type"], "m.device_list_update")
content = edu["content"]

expected = {
"user_id": user_id,
"device_id": device_id,
"prev_id": [prev_stream_id] if prev_stream_id is not None else [],
}

self.assertLessEqual(expected.items(), content.items())
return content["stream_id"]

def check_signing_key_update_txn(self, txn: JsonDict,) -> None:
"""Check that the txn has an EDU with a signing key update.
"""
edus = txn["edus"]
self.assertEqual(len(edus), 1)

def generate_and_upload_device_signing_key(
self, user_id: str, device_id: str
) -> SigningKey:
"""Generate a signing keypair for the given device, and upload it"""
sk = key.generate_signing_key(device_id)

device_dict = build_device_dict(user_id, device_id, sk)

self.get_success(
self.hs.get_e2e_keys_handler().upload_keys_for_user(
user_id, device_id, {"device_keys": device_dict},
)
)
return sk


def generate_self_id_key() -> SigningKey:
"""generate a signing key whose version is its public key
... as used by the cross-signing-keys.
"""
k = key.generate_signing_key("x")
k.version = encode_pubkey(k)
return k


def key_id(k: BaseKey) -> str:
return "%s:%s" % (k.alg, k.version)


def encode_pubkey(sk: SigningKey) -> str:
"""Encode the public key corresponding to the given signing key as base64"""
return key.encode_verify_key_base64(key.get_verify_key(sk))


def build_device_dict(user_id: str, device_id: str, sk: SigningKey):
"""Build a dict representing the given device"""
return {
"user_id": user_id,
"device_id": device_id,
"algorithms": ["m.olm.curve25519-aes-sha256", "m.megolm.v1.aes-sha"],
"keys": {
"curve25519:" + device_id: "curve25519+key",
key_id(sk): encode_pubkey(sk),
},
}
1 change: 1 addition & 0 deletions tests/unittest.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,7 @@ def register_user(self, username, password, admin=False):
"password": password,
"admin": admin,
"mac": want_mac,
"inhibit_login": True,
}
)
request, channel = self.make_request(
Expand Down

0 comments on commit 665630f

Please sign in to comment.