From ad81066e2a2b53b029a7c941cdaace2b03ea6fff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Humbert?= Date: Fri, 3 Sep 2021 13:56:25 +0200 Subject: [PATCH 1/5] Isolate mediation setup tests in a separate class. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Clément Humbert --- aries_cloudagent/core/tests/test_conductor.py | 204 +++++++++--------- 1 file changed, 105 insertions(+), 99 deletions(-) diff --git a/aries_cloudagent/core/tests/test_conductor.py b/aries_cloudagent/core/tests/test_conductor.py index 705d2ebc74..786c959d89 100644 --- a/aries_cloudagent/core/tests/test_conductor.py +++ b/aries_cloudagent/core/tests/test_conductor.py @@ -722,105 +722,6 @@ async def test_print_invite_connection(self): assert "http://localhost?oob=" in value assert "http://localhost?c_i=" in value - async def test_mediator_invitation_0160(self): - builder: ContextBuilder = StubContextBuilder(self.test_settings) - builder.update_settings({"mediation.invite": "test-invite"}) - builder.update_settings({"mediation.connections_invite": True}) - conductor = test_module.Conductor(builder) - - await conductor.setup() - - mock_conn_record = async_mock.MagicMock() - - with async_mock.patch.object( - test_module.ConnectionInvitation, "from_url" - ) as mock_from_url, async_mock.patch.object( - test_module, - "ConnectionManager", - async_mock.MagicMock( - return_value=async_mock.MagicMock( - receive_invitation=async_mock.CoroutineMock( - return_value=mock_conn_record - ) - ) - ), - ) as mock_mgr, async_mock.patch.object( - mock_conn_record, "metadata_set", async_mock.CoroutineMock() - ), async_mock.patch.object( - test_module, - "LOGGER", - async_mock.MagicMock( - exception=async_mock.MagicMock( - side_effect=Exception("This method should not have been called") - ) - ), - ): - await conductor.start() - await conductor.stop() - mock_from_url.assert_called_once_with("test-invite") - mock_mgr.return_value.receive_invitation.assert_called_once() - - async def test_mediator_invitation_0434(self): - builder: ContextBuilder = StubContextBuilder(self.test_settings) - builder.update_settings({"mediation.invite": "test-invite"}) - conductor = test_module.Conductor(builder) - - await conductor.setup() - - conn_record = ConnRecord( - invitation_key="3Dn1SJNPaCXcvvJvSbsFWP2xaCjMom3can8CQNhWrTRx", - their_label="Hello", - their_role=ConnRecord.Role.RESPONDER.rfc160, - alias="Bob", - ) - conn_record.accept = ConnRecord.ACCEPT_MANUAL - await conn_record.save(await conductor.root_profile.session()) - with async_mock.patch.object( - test_module.InvitationMessage, "from_url" - ) as mock_from_url, async_mock.patch.object( - test_module, - "OutOfBandManager", - async_mock.MagicMock( - return_value=async_mock.MagicMock( - receive_invitation=async_mock.CoroutineMock( - return_value=conn_record - ) - ) - ), - ) as mock_mgr, async_mock.patch.object( - test_module, - "LOGGER", - async_mock.MagicMock( - exception=async_mock.MagicMock( - side_effect=Exception("This method should not have been called") - ) - ), - ): - await conductor.start() - await conductor.stop() - mock_from_url.assert_called_once_with("test-invite") - mock_mgr.return_value.receive_invitation.assert_called_once() - - async def test_mediator_invitation_x(self): - builder: ContextBuilder = StubContextBuilder(self.test_settings) - builder.update_settings({"mediation.invite": "test-invite"}) - builder.update_settings({"mediation.connections_invite": True}) - conductor = test_module.Conductor(builder) - - await conductor.setup() - - with async_mock.patch.object( - test_module.ConnectionInvitation, - "from_url", - async_mock.MagicMock(side_effect=Exception()), - ) as mock_from_url, async_mock.patch.object( - test_module, "LOGGER" - ) as mock_logger: - await conductor.start() - await conductor.stop() - mock_from_url.assert_called_once_with("test-invite") - mock_logger.exception.assert_called_once() - async def test_clear_default_mediator(self): builder: ContextBuilder = StubContextBuilder(self.test_settings) builder.update_settings({"mediation.clear": True}) @@ -946,3 +847,108 @@ async def test_shutdown_multitenant_profiles(self): multitenant_mgr._instances["test1"].close.assert_called_once_with() multitenant_mgr._instances["test2"].close.assert_called_once_with() + + +class TestConductorMediationSetup(AsyncTestCase, Config): + """ + Test related with setting up mediation from given arguments or stored invitation. + """ + + async def test_mediator_invitation_0160(self): + builder: ContextBuilder = StubContextBuilder(self.test_settings) + builder.update_settings({"mediation.invite": "test-invite"}) + builder.update_settings({"mediation.connections_invite": True}) + conductor = test_module.Conductor(builder) + + await conductor.setup() + + mock_conn_record = async_mock.MagicMock() + + with async_mock.patch.object( + test_module.ConnectionInvitation, "from_url" + ) as mock_from_url, async_mock.patch.object( + test_module, + "ConnectionManager", + async_mock.MagicMock( + return_value=async_mock.MagicMock( + receive_invitation=async_mock.CoroutineMock( + return_value=mock_conn_record + ) + ) + ), + ) as mock_mgr, async_mock.patch.object( + mock_conn_record, "metadata_set", async_mock.CoroutineMock() + ), async_mock.patch.object( + test_module, + "LOGGER", + async_mock.MagicMock( + exception=async_mock.MagicMock( + side_effect=Exception("This method should not have been called") + ) + ), + ): + await conductor.start() + await conductor.stop() + mock_from_url.assert_called_once_with("test-invite") + mock_mgr.return_value.receive_invitation.assert_called_once() + + async def test_mediator_invitation_0434(self): + builder: ContextBuilder = StubContextBuilder(self.test_settings) + builder.update_settings({"mediation.invite": "test-invite"}) + conductor = test_module.Conductor(builder) + + await conductor.setup() + + conn_record = ConnRecord( + invitation_key="3Dn1SJNPaCXcvvJvSbsFWP2xaCjMom3can8CQNhWrTRx", + their_label="Hello", + their_role=ConnRecord.Role.RESPONDER.rfc160, + alias="Bob", + ) + conn_record.accept = ConnRecord.ACCEPT_MANUAL + await conn_record.save(await conductor.root_profile.session()) + with async_mock.patch.object( + test_module.InvitationMessage, "from_url" + ) as mock_from_url, async_mock.patch.object( + test_module, + "OutOfBandManager", + async_mock.MagicMock( + return_value=async_mock.MagicMock( + receive_invitation=async_mock.CoroutineMock( + return_value=conn_record + ) + ) + ), + ) as mock_mgr, async_mock.patch.object( + test_module, + "LOGGER", + async_mock.MagicMock( + exception=async_mock.MagicMock( + side_effect=Exception("This method should not have been called") + ) + ), + ): + await conductor.start() + await conductor.stop() + mock_from_url.assert_called_once_with("test-invite") + mock_mgr.return_value.receive_invitation.assert_called_once() + + async def test_mediator_invitation_x(self): + builder: ContextBuilder = StubContextBuilder(self.test_settings) + builder.update_settings({"mediation.invite": "test-invite"}) + builder.update_settings({"mediation.connections_invite": True}) + conductor = test_module.Conductor(builder) + + await conductor.setup() + + with async_mock.patch.object( + test_module.ConnectionInvitation, + "from_url", + async_mock.MagicMock(side_effect=Exception()), + ) as mock_from_url, async_mock.patch.object( + test_module, "LOGGER" + ) as mock_logger: + await conductor.start() + await conductor.stop() + mock_from_url.assert_called_once_with("test-invite") + mock_logger.exception.assert_called_once() From eeb1fbf4217f3b17094a58aba206b673ea61a331 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Humbert?= Date: Mon, 6 Sep 2021 11:13:02 +0200 Subject: [PATCH 2/5] Make mediation invitation parameter idempotent MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit To ease Aca-py's deployment in scenarios where command line arguments cannot be changed easily, make `--mediatior-invitation` and `--mediator-connections-invite` available to the `provision` command. `provision` stores the mediator invite to be used by `start` command. If the invitation changed, `start` updates the stored invite and uses the new one. Signed-off-by: Clément Humbert --- aries_cloudagent/commands/provision.py | 13 ++ .../commands/tests/test_provision.py | 19 ++- aries_cloudagent/config/argparse.py | 61 +++++--- aries_cloudagent/core/conductor.py | 51 ++++--- aries_cloudagent/core/tests/test_conductor.py | 116 ++++++++++++-- .../mediation_invite_store.py | 141 ++++++++++++++++++ .../v1_0/tests/test_mediation_invite_store.py | 120 +++++++++++++++ 7 files changed, 470 insertions(+), 51 deletions(-) create mode 100644 aries_cloudagent/protocols/coordinate_mediation/mediation_invite_store.py create mode 100644 aries_cloudagent/protocols/coordinate_mediation/v1_0/tests/test_mediation_invite_store.py diff --git a/aries_cloudagent/commands/provision.py b/aries_cloudagent/commands/provision.py index 333a7550eb..8d2f0b7bf6 100644 --- a/aries_cloudagent/commands/provision.py +++ b/aries_cloudagent/commands/provision.py @@ -10,6 +10,11 @@ from ..config.ledger import get_genesis_transactions, ledger_config from ..config.util import common_config from ..config.wallet import wallet_config +from ..protocols.coordinate_mediation.mediation_invite_store import ( + MediationInviteStore, + MediationInviteRecord, +) +from ..storage.base import BaseStorage from . import PROG @@ -35,6 +40,14 @@ async def provision(settings: dict): root_profile, public_did = await wallet_config(context, provision=True) + # store mediator invite url if provided + mediation_invite = settings.get("mediation.invite", None) + if mediation_invite: + async with root_profile.session() as session: + await MediationInviteStore(session.context.inject(BaseStorage)).store( + MediationInviteRecord.unused(mediation_invite) + ) + if await ledger_config(root_profile, public_did and public_did.did, True): print("Ledger configured") else: diff --git a/aries_cloudagent/commands/tests/test_provision.py b/aries_cloudagent/commands/tests/test_provision.py index adfa442c40..f8fa55ea3a 100644 --- a/aries_cloudagent/commands/tests/test_provision.py +++ b/aries_cloudagent/commands/tests/test_provision.py @@ -4,8 +4,10 @@ from ...config.base import ConfigError from ...config.error import ArgsParseError -from ...core.profile import Profile from .. import provision as test_module +from ...protocols.coordinate_mediation.mediation_invite_store import ( + MediationInviteRecord, +) class TestProvision(AsyncTestCase): @@ -68,3 +70,18 @@ def test_main(self): ) as mock_execute: test_module.main() mock_execute.assert_called_once + + async def test_provision_should_store_provided_mediation_invite(self): + # given + mediation_invite = "test-invite" + + with async_mock.patch.object( + test_module.MediationInviteStore, "store" + ) as invite_store: + # when + await test_module.provision({"mediation.invite": mediation_invite}) + + # then + invite_store.assert_called_with( + MediationInviteRecord(mediation_invite, False) + ) diff --git a/aries_cloudagent/config/argparse.py b/aries_cloudagent/config/argparse.py index b003e96ce3..48502debd7 100644 --- a/aries_cloudagent/config/argparse.py +++ b/aries_cloudagent/config/argparse.py @@ -1139,25 +1139,18 @@ def get_settings(self, args: Namespace): return settings -@group(CAT_START) -class MediationGroup(ArgumentGroup): - """Mediation settings.""" +@group(CAT_START, CAT_PROVISION) +class MediationInviteGroup(ArgumentGroup): + """ + Mediation invitation settings. - GROUP_NAME = "Mediation" + These can be provided at provision- and start-time. + """ + + GROUP_NAME = "Mediation invitation" def add_arguments(self, parser: ArgumentParser): - """Add mediation command line arguments to the parser.""" - parser.add_argument( - "--open-mediation", - action="store_true", - env_var="ACAPY_MEDIATION_OPEN", - help=( - "Enables didcomm mediation. After establishing a connection, " - "if enabled, an agent may request message mediation, which will " - "allow the mediator to forward messages on behalf of the recipient. " - "See aries-rfc:0211." - ), - ) + """Add mediation invitation command line arguments to the parser.""" parser.add_argument( "--mediator-invitation", type=str, @@ -1178,6 +1171,38 @@ def add_arguments(self, parser: ArgumentParser): "Default: false." ), ) + + def get_settings(self, args: Namespace): + """Extract mediation invitation settings.""" + settings = {} + if args.mediator_invitation: + settings["mediation.invite"] = args.mediator_invitation + if args.mediator_connections_invite: + settings["mediation.connections_invite"] = True + + return settings + + +@group(CAT_START) +class MediationGroup(ArgumentGroup): + """Mediation settings.""" + + GROUP_NAME = "Mediation" + + def add_arguments(self, parser: ArgumentParser): + """Add mediation command line arguments to the parser.""" + parser.add_argument( + "--open-mediation", + action="store_true", + env_var="ACAPY_MEDIATION_OPEN", + help=( + "Enables didcomm mediation. After establishing a connection, " + "if enabled, an agent may request message mediation, which will " + "allow the mediator to forward messages on behalf of the recipient. " + "See aries-rfc:0211." + ), + ) + parser.add_argument( "--default-mediator-id", type=str, @@ -1197,14 +1222,10 @@ def get_settings(self, args: Namespace): settings = {} if args.open_mediation: settings["mediation.open"] = True - if args.mediator_invitation: - settings["mediation.invite"] = args.mediator_invitation if args.default_mediator_id: settings["mediation.default_id"] = args.default_mediator_id if args.clear_default_mediator: settings["mediation.clear"] = True - if args.mediator_connections_invite: - settings["mediation.connections_invite"] = True if args.clear_default_mediator and args.default_mediator_id: raise ArgsParseError( diff --git a/aries_cloudagent/core/conductor.py b/aries_cloudagent/core/conductor.py index d57b0ad6a4..b7fc0098ed 100644 --- a/aries_cloudagent/core/conductor.py +++ b/aries_cloudagent/core/conductor.py @@ -31,8 +31,10 @@ ConnectionInvitation, ) from ..protocols.coordinate_mediation.v1_0.manager import MediationManager +from ..protocols.coordinate_mediation.mediation_invite_store import MediationInviteStore from ..protocols.out_of_band.v1_0.manager import OutOfBandManager from ..protocols.out_of_band.v1_0.messages.invitation import HSProto, InvitationMessage +from ..storage.base import BaseStorage from ..transport.inbound.manager import InboundTransportManager from ..transport.inbound.message import InboundMessage from ..transport.outbound.base import OutboundDeliveryError @@ -315,26 +317,41 @@ async def start(self) -> None: ) async with self.root_profile.session() as session: - mgr = ( - ConnectionManager(session) - if mediation_connections_invite - else OutOfBandManager(session) + invite_store = MediationInviteStore( + session.context.inject(BaseStorage) ) - - conn_record = await mgr.receive_invitation( - invitation=invitation_handler.from_url(mediation_invitation), - auto_accept=True, + default_invite_record = ( + await invite_store.retrieve_and_update_mediation_record( + mediation_invitation + ) ) - await conn_record.metadata_set( - session, MediationManager.SEND_REQ_AFTER_CONNECTION, True - ) - await conn_record.metadata_set( - session, MediationManager.SET_TO_DEFAULT_ON_GRANTED, True - ) - print("Attempting to connect to mediator...") - del mgr - except Exception: + if not default_invite_record.used: + mgr = ( + ConnectionManager(session) + if mediation_connections_invite + else OutOfBandManager(session) + ) + + conn_record = await mgr.receive_invitation( + invitation=invitation_handler.from_url( + default_invite_record.invite + ), + auto_accept=True, + ) + await invite_store.mark_default_invite_as_used() + + await conn_record.metadata_set( + session, MediationManager.SEND_REQ_AFTER_CONNECTION, True + ) + await conn_record.metadata_set( + session, MediationManager.SET_TO_DEFAULT_ON_GRANTED, True + ) + + print("Attempting to connect to mediator...") + del mgr + except Exception as e: + print(e) LOGGER.exception("Error accepting mediation invitation") async def stop(self, timeout=1.0): diff --git a/aries_cloudagent/core/tests/test_conductor.py b/aries_cloudagent/core/tests/test_conductor.py index 786c959d89..76d292e65a 100644 --- a/aries_cloudagent/core/tests/test_conductor.py +++ b/aries_cloudagent/core/tests/test_conductor.py @@ -1,5 +1,6 @@ from io import StringIO +import asynctest from asynctest import TestCase as AsyncTestCase from asynctest import mock as async_mock @@ -17,6 +18,9 @@ from ...core.in_memory import InMemoryProfileManager from ...core.profile import ProfileManager from ...core.protocol_registry import ProtocolRegistry +from ...protocols.coordinate_mediation.mediation_invite_store import ( + MediationInviteRecord, +) from ...protocols.coordinate_mediation.v1_0.models.mediation_record import ( MediationRecord, ) @@ -854,12 +858,35 @@ class TestConductorMediationSetup(AsyncTestCase, Config): Test related with setting up mediation from given arguments or stored invitation. """ - async def test_mediator_invitation_0160(self): + def __get_invite_store_mock( + self, invite_string: str, invite_already_used: bool = False + ) -> async_mock.MagicMock: + unused_invite = MediationInviteRecord(invite_string, invite_already_used) + used_invite = MediationInviteRecord(invite_string, used=True) + + return async_mock.MagicMock( + retrieve_and_update_mediation_record=async_mock.CoroutineMock( + return_value=unused_invite + ), + mark_default_invite_as_used=async_mock.CoroutineMock( + return_value=used_invite + ), + ) + + def __get_mediator_config( + self, invite_string: str, connections_invite: bool = False + ) -> ContextBuilder: builder: ContextBuilder = StubContextBuilder(self.test_settings) - builder.update_settings({"mediation.invite": "test-invite"}) - builder.update_settings({"mediation.connections_invite": True}) - conductor = test_module.Conductor(builder) + builder.update_settings({"mediation.invite": invite_string}) + if connections_invite: + builder.update_settings({"mediation.connections_invite": True}) + + return builder + async def test_mediator_invitation_0160(self): + conductor = test_module.Conductor( + self.__get_mediator_config("test-invite", True) + ) await conductor.setup() mock_conn_record = async_mock.MagicMock() @@ -893,10 +920,9 @@ async def test_mediator_invitation_0160(self): mock_mgr.return_value.receive_invitation.assert_called_once() async def test_mediator_invitation_0434(self): - builder: ContextBuilder = StubContextBuilder(self.test_settings) - builder.update_settings({"mediation.invite": "test-invite"}) - conductor = test_module.Conductor(builder) - + conductor = test_module.Conductor( + self.__get_mediator_config("test-invite", False) + ) await conductor.setup() conn_record = ConnRecord( @@ -933,12 +959,76 @@ async def test_mediator_invitation_0434(self): mock_from_url.assert_called_once_with("test-invite") mock_mgr.return_value.receive_invitation.assert_called_once() - async def test_mediator_invitation_x(self): - builder: ContextBuilder = StubContextBuilder(self.test_settings) - builder.update_settings({"mediation.invite": "test-invite"}) - builder.update_settings({"mediation.connections_invite": True}) - conductor = test_module.Conductor(builder) + async def test_mediation_invitation_should_use_stored_invitation(self): + # given + invite_string = "test-invite" + conductor = test_module.Conductor( + self.__get_mediator_config(invite_string, True) + ) + await conductor.setup() + mock_conn_record = async_mock.MagicMock() + + invite_store_mock = self.__get_invite_store_mock(invite_string) + connection_manager_mock = async_mock.MagicMock( + receive_invitation=async_mock.CoroutineMock(return_value=mock_conn_record) + ) + + # when + with async_mock.patch.object( + test_module, "MediationInviteStore", return_value=invite_store_mock + ), async_mock.patch.object( + test_module.ConnectionInvitation, "from_url" + ) as mock_connection_from_url, async_mock.patch.object( + test_module, "ConnectionManager", return_value=connection_manager_mock + ), async_mock.patch.object( + mock_conn_record, "metadata_set", async_mock.CoroutineMock() + ): + await conductor.start() + await conductor.stop() + + # then + invite_store_mock.retrieve_and_update_mediation_record.assert_called_with( + invite_string + ) + connection_manager_mock.receive_invitation.assert_called_once() + mock_connection_from_url.assert_called_with(invite_string) + + async def test_mediation_invitation_should_not_establish_new_connection_for_used_invitation( + self, + ): + # given + invite_string = "test-invite" + + conductor = test_module.Conductor( + self.__get_mediator_config(invite_string, True) + ) + await conductor.setup() + + invite_store_mock = self.__get_invite_store_mock(invite_string, True) + connection_manager_mock = async_mock.MagicMock( + receive_invitation=async_mock.CoroutineMock() + ) + + # when + with async_mock.patch.object( + test_module, "MediationInviteStore", return_value=invite_store_mock + ), async_mock.patch.object( + test_module, "ConnectionManager", return_value=connection_manager_mock + ): + await conductor.start() + await conductor.stop() + + # then + invite_store_mock.retrieve_and_update_mediation_record.assert_called_with( + invite_string + ) + connection_manager_mock.receive_invitation.assert_not_called() + + async def test_mediator_invitation_x(self): + conductor = test_module.Conductor( + self.__get_mediator_config("test-invite", True) + ) await conductor.setup() with async_mock.patch.object( diff --git a/aries_cloudagent/protocols/coordinate_mediation/mediation_invite_store.py b/aries_cloudagent/protocols/coordinate_mediation/mediation_invite_store.py new file mode 100644 index 0000000000..7b2adf16fd --- /dev/null +++ b/aries_cloudagent/protocols/coordinate_mediation/mediation_invite_store.py @@ -0,0 +1,141 @@ +""" +Storage management for configuration-provided mediation invite. + +Handle storage and retrieval of mediation invites provided through arguments. +Enables having the mediation invite config be the same +for `provision` and `starting` commands. +""" +import dataclasses +import json +from typing import Optional + +from aries_cloudagent.storage.base import BaseStorage +from aries_cloudagent.storage.error import StorageNotFoundError +from aries_cloudagent.storage.record import StorageRecord + + +@dataclasses.dataclass +class MediationInviteRecord: + """A record to store mediation invites and their freshness.""" + + invite: str + used: bool + + def to_json(self) -> str: + """:return: The current record serialized into a json string.""" + return json.dumps({"invite": self.invite, "used": self.used}) + + @staticmethod + def from_json(json_invite_record: str) -> "MediationInviteRecord": + """:return: a mediation invite record deserialized from a json string.""" + return MediationInviteRecord(**json.loads(json_invite_record)) + + @staticmethod + def unused(invite: str) -> "MediationInviteRecord": + """ + :param invite: invite string as provided by the mediator. + + :return: An unused mediation invitation for the given invite string + """ + return MediationInviteRecord(invite, False) + + +class NoDefaultMediationInviteException(Exception): + """Raised if trying to mark a default invite as used when none exist.""" + + +class MediationInviteStore: + """Store and retrieve mediation invite configuration.""" + + INVITE_RECORD_CATEGORY = "config" + MEDIATION_INVITE_ID = "mediation_invite" + + def __init__(self, storage: BaseStorage): + """:param storage: storage facility to be used to store mediation invitation.""" + self.__storage = storage + + async def __retrieve_record(self, key: str) -> Optional[StorageRecord]: + try: + return await self.__storage.get_record(self.INVITE_RECORD_CATEGORY, key) + except StorageNotFoundError: + return None + + async def store( + self, mediation_invite: MediationInviteRecord + ) -> MediationInviteRecord: + """ + Store the mediator's invite for further use when starting the agent. + + Update the currently stored invite if one already exists. + This assumes a new invite and as such, marks it as unused. + + :param mediation_invite: mediation invite url + :return: stored mediation invite + """ + current_invite_record = await self.__retrieve_record(self.MEDIATION_INVITE_ID) + + if current_invite_record is None: + await self.__storage.add_record( + StorageRecord( + type=self.INVITE_RECORD_CATEGORY, + id=self.MEDIATION_INVITE_ID, + value=mediation_invite.to_json(), + ) + ) + else: + await self.__storage.update_record( + current_invite_record, + mediation_invite.to_json(), + tags=current_invite_record.tags, + ) + + return mediation_invite + + async def __retrieve(self) -> Optional[MediationInviteRecord]: + """:return: the currently stored mediation invite url.""" + + invite_record = await self.__retrieve_record(self.MEDIATION_INVITE_ID) + return ( + MediationInviteRecord.from_json(invite_record.value) + if invite_record is not None + else None + ) + + async def retrieve_and_update_mediation_record( + self, provided_mediation_invitation: str + ) -> MediationInviteRecord: + """ + Retrieve stored mediation invite and optionally updates it. + + Stored value is updated if `provided_mediation_invitation` has changed. + Updated record is marked as unused. + + :param provided_mediation_invitation: mediation invite provided by user + :return: stored mediation invite + """ + default_invite = await self.__retrieve() + + if default_invite != provided_mediation_invitation: + default_invite = await self.store( + MediationInviteRecord.unused(provided_mediation_invitation) + ) + + return default_invite + + async def mark_default_invite_as_used(self): + """ + Mark the currently stored invitation as used if one exists. + + :raises NoDefaultMediationInviteException: + if trying to mark invite as used when there is no invite stored. + """ + record = await self.__retrieve() + if not record: + raise NoDefaultMediationInviteException( + "No default mediation invite: cannot mark it as used." + ) + + updated_record = MediationInviteRecord(record.invite, used=True) + await self.store(updated_record) + + return updated_record diff --git a/aries_cloudagent/protocols/coordinate_mediation/v1_0/tests/test_mediation_invite_store.py b/aries_cloudagent/protocols/coordinate_mediation/v1_0/tests/test_mediation_invite_store.py new file mode 100644 index 0000000000..123a53a22f --- /dev/null +++ b/aries_cloudagent/protocols/coordinate_mediation/v1_0/tests/test_mediation_invite_store.py @@ -0,0 +1,120 @@ +from asynctest import TestCase as AsyncTestCase +from unittest import TestCase +from asynctest import mock as async_mock + +from aries_cloudagent.protocols.coordinate_mediation.mediation_invite_store import ( + MediationInviteStore, + MediationInviteRecord, + NoDefaultMediationInviteException, +) +from aries_cloudagent.storage.base import BaseStorage +from aries_cloudagent.storage.error import StorageNotFoundError +from aries_cloudagent.storage.record import StorageRecord + + +def _storage_record_for(value: str, used: bool = False) -> StorageRecord: + return StorageRecord( + type=MediationInviteStore.INVITE_RECORD_CATEGORY, + value=f"""{{"invite": "{value}", "used": {str(used).lower()}}}""", + tags={}, + id=MediationInviteStore.MEDIATION_INVITE_ID, + ) + + +class TestMediationInviteRecord(TestCase): + def test_to_json_should_dump_record(self): + # given + invite_record = MediationInviteRecord("some_invite", True) + + # when + json_record = invite_record.to_json() + + # then + assert json_record == """{"invite": "some_invite", "used": true}""" + + def test_from_json_should_create_record_from_json(self): + # given + json_record = """{"invite": "some_invite", "used": true}""" + + # when + record = MediationInviteRecord.from_json(json_record) + + # then + assert record == MediationInviteRecord("some_invite", True) + + def test_unused_should_create_unused_record(self): + # when - then + assert not MediationInviteRecord.unused("some_other_invite").used + + +class TestMediationInviteStore(AsyncTestCase): + def setUp(self): + self.storage = async_mock.MagicMock(spec=BaseStorage) + self.mediation_invite_store = MediationInviteStore(self.storage) + + async def test_retrieve_update_should_create_record_to_store_mediation_invite_when_no_record_exists( + self, + ): + # given + mediation_invite_url = "somepla.ce:4242/alongandunreadablebase64payload" + self.storage.get_record.side_effect = StorageNotFoundError + + expected_updated_record = MediationInviteRecord.unused(mediation_invite_url) + + # when + stored_invite = ( + await self.mediation_invite_store.retrieve_and_update_mediation_record( + mediation_invite_url + ) + ) + + # then + self.storage.add_record.assert_called_with( + _storage_record_for(mediation_invite_url) + ) + assert stored_invite == expected_updated_record + + async def test_retrieve_update_should_update_record_when_a_mediation_invite_record_exists( + self, + ): + # given + stored_record = _storage_record_for("some old url") + mediation_invite_url = "somepla.ce:4242/alongandunreadablebase64payload" + self.storage.get_record.return_value = stored_record + + expected_updated_record = MediationInviteRecord.unused(mediation_invite_url) + + # when + stored_invite = ( + await self.mediation_invite_store.retrieve_and_update_mediation_record( + mediation_invite_url + ) + ) + + # then + self.storage.update_record.assert_called_with( + stored_record, expected_updated_record.to_json(), tags=stored_record.tags + ) + + assert stored_invite == expected_updated_record + + async def test_mark_default_invite_as_used_should_mark_stored_invite(self): + # given + stored_record = _storage_record_for("some old url") + self.storage.get_record.return_value = stored_record + + # when + updated_invite_record = ( + await self.mediation_invite_store.mark_default_invite_as_used() + ) + + # then + assert updated_invite_record.used + + async def test_mark_default_invite_as_used_should_raise_when_no_invite(self): + # given + self.storage.get_record.return_value = None + + # when - then + with self.assertRaises(NoDefaultMediationInviteException): + await self.mediation_invite_store.mark_default_invite_as_used() From 2204e25aa547a7710861372ad9158db1becec7aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Humbert?= Date: Mon, 13 Sep 2021 15:00:22 +0200 Subject: [PATCH 3/5] When setting up mediation: clear previous default MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Whenever the conductor needs to setup a new mediation connection (eg: the invitation url provided by the command line arguments has changed), clear the old default mediator settings prior to setting up the new connection and default mediator. Signed-off-by: Clément Humbert --- aries_cloudagent/core/conductor.py | 4 ++++ aries_cloudagent/core/tests/test_conductor.py | 14 +++++++++++++- 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/aries_cloudagent/core/conductor.py b/aries_cloudagent/core/conductor.py index b7fc0098ed..5ed3a6e778 100644 --- a/aries_cloudagent/core/conductor.py +++ b/aries_cloudagent/core/conductor.py @@ -327,6 +327,10 @@ async def start(self) -> None: ) if not default_invite_record.used: + # clear previous mediator configuration before establishing a + # new one + await MediationManager(session.profile).clear_default_mediator() + mgr = ( ConnectionManager(session) if mediation_connections_invite diff --git a/aries_cloudagent/core/tests/test_conductor.py b/aries_cloudagent/core/tests/test_conductor.py index 76d292e65a..8fe16ed448 100644 --- a/aries_cloudagent/core/tests/test_conductor.py +++ b/aries_cloudagent/core/tests/test_conductor.py @@ -960,6 +960,12 @@ async def test_mediator_invitation_0434(self): mock_mgr.return_value.receive_invitation.assert_called_once() async def test_mediation_invitation_should_use_stored_invitation(self): + """ + Conductor should store the mediation invite if it differs from the stored one or + if the stored one was not used yet. + + Using a mediation invitation should clear the previously set default mediator. + """ # given invite_string = "test-invite" @@ -973,6 +979,9 @@ async def test_mediation_invitation_should_use_stored_invitation(self): connection_manager_mock = async_mock.MagicMock( receive_invitation=async_mock.CoroutineMock(return_value=mock_conn_record) ) + mock_mediation_manager = async_mock.MagicMock( + clear_default_mediator=async_mock.CoroutineMock() + ) # when with async_mock.patch.object( @@ -983,6 +992,8 @@ async def test_mediation_invitation_should_use_stored_invitation(self): test_module, "ConnectionManager", return_value=connection_manager_mock ), async_mock.patch.object( mock_conn_record, "metadata_set", async_mock.CoroutineMock() + ), async_mock.patch.object( + test_module, 'MediationManager', return_value=mock_mediation_manager ): await conductor.start() await conductor.stop() @@ -993,8 +1004,9 @@ async def test_mediation_invitation_should_use_stored_invitation(self): ) connection_manager_mock.receive_invitation.assert_called_once() mock_connection_from_url.assert_called_with(invite_string) + mock_mediation_manager.clear_default_mediator.assert_called_once() - async def test_mediation_invitation_should_not_establish_new_connection_for_used_invitation( + async def test_mediation_invitation_should_not_create_connection_for_old_invitation( self, ): # given From 500c52224c3267bfacd6758691078f5da14227a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Humbert?= Date: Tue, 14 Sep 2021 16:18:39 +0200 Subject: [PATCH 4/5] Left-behind formatting MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Clément Humbert --- aries_cloudagent/core/tests/test_conductor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aries_cloudagent/core/tests/test_conductor.py b/aries_cloudagent/core/tests/test_conductor.py index 6ce1319d2d..2812770b2f 100644 --- a/aries_cloudagent/core/tests/test_conductor.py +++ b/aries_cloudagent/core/tests/test_conductor.py @@ -1005,7 +1005,7 @@ async def test_mediation_invitation_should_use_stored_invitation(self): ), async_mock.patch.object( mock_conn_record, "metadata_set", async_mock.CoroutineMock() ), async_mock.patch.object( - test_module, 'MediationManager', return_value=mock_mediation_manager + test_module, "MediationManager", return_value=mock_mediation_manager ): await conductor.start() await conductor.stop() From 59a62b1b11bca39c8f5a00a259cf6d8d33203690 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Humbert?= Date: Fri, 17 Sep 2021 09:24:21 +0200 Subject: [PATCH 5/5] Mediation invite is not mandatory if provided at provision time MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Mediation invite can be: * Specified at provision time and repeated at start time without creating more connections that necessary * Specified at provision and not repeated afterwards Stored invite is updated if a new invite is provided through arguments. Signed-off-by: Clément Humbert --- aries_cloudagent/core/conductor.py | 50 +++---- aries_cloudagent/core/tests/test_conductor.py | 125 +++++++++--------- .../mediation_invite_store.py | 35 ++++- .../v1_0/tests/test_mediation_invite_store.py | 61 +++++++-- 4 files changed, 170 insertions(+), 101 deletions(-) diff --git a/aries_cloudagent/core/conductor.py b/aries_cloudagent/core/conductor.py index 8f4daba481..4c10235b36 100644 --- a/aries_cloudagent/core/conductor.py +++ b/aries_cloudagent/core/conductor.py @@ -305,30 +305,31 @@ async def start(self) -> None: except Exception: LOGGER.exception("Error creating invitation") - # Accept mediation invitation if specified - mediation_invitation: str = context.settings.get("mediation.invite") - if mediation_invitation: + # mediation connection establishment + provided_invite: str = context.settings.get("mediation.invite") + async with self.root_profile.session() as session: try: - mediation_connections_invite = context.settings.get( - "mediation.connections_invite", False - ) - invitation_handler = ( - ConnectionInvitation - if mediation_connections_invite - else InvitationMessage + invite_store = MediationInviteStore(session.context.inject(BaseStorage)) + mediation_invite_record = ( + await invite_store.get_mediation_invite_record(provided_invite) ) + except Exception: + LOGGER.exception("Error retrieving mediator invitation") + mediation_invite_record = None - async with self.root_profile.session() as session: - invite_store = MediationInviteStore( - session.context.inject(BaseStorage) + # Accept mediation invitation if one was specified or stored + if mediation_invite_record is not None: + try: + mediation_connections_invite = context.settings.get( + "mediation.connections_invite", False ) - default_invite_record = ( - await invite_store.retrieve_and_update_mediation_record( - mediation_invitation - ) + invitation_handler = ( + ConnectionInvitation + if mediation_connections_invite + else InvitationMessage ) - if not default_invite_record.used: + if not mediation_invite_record.used: # clear previous mediator configuration before establishing a # new one await MediationManager(session.profile).clear_default_mediator() @@ -341,11 +342,15 @@ async def start(self) -> None: conn_record = await mgr.receive_invitation( invitation=invitation_handler.from_url( - default_invite_record.invite + mediation_invite_record.invite ), auto_accept=True, ) - await invite_store.mark_default_invite_as_used() + await ( + MediationInviteStore( + session.context.inject(BaseStorage) + ).mark_default_invite_as_used() + ) await conn_record.metadata_set( session, MediationManager.SEND_REQ_AFTER_CONNECTION, True @@ -356,9 +361,8 @@ async def start(self) -> None: print("Attempting to connect to mediator...") del mgr - except Exception as e: - print(e) - LOGGER.exception("Error accepting mediation invitation") + except Exception: + LOGGER.exception("Error accepting mediation invitation") async def stop(self, timeout=1.0): """Stop the agent.""" diff --git a/aries_cloudagent/core/tests/test_conductor.py b/aries_cloudagent/core/tests/test_conductor.py index 2812770b2f..d703660757 100644 --- a/aries_cloudagent/core/tests/test_conductor.py +++ b/aries_cloudagent/core/tests/test_conductor.py @@ -20,6 +20,7 @@ from ...core.protocol_registry import ProtocolRegistry from ...protocols.coordinate_mediation.mediation_invite_store import ( MediationInviteRecord, + MediationInviteStore, ) from ...protocols.coordinate_mediation.v1_0.models.mediation_record import ( MediationRecord, @@ -865,26 +866,25 @@ async def test_shutdown_multitenant_profiles(self): multitenant_mgr._instances["test2"].close.assert_called_once_with() +def get_invite_store_mock( + invite_string: str, invite_already_used: bool = False +) -> async_mock.MagicMock: + unused_invite = MediationInviteRecord(invite_string, invite_already_used) + used_invite = MediationInviteRecord(invite_string, used=True) + + return async_mock.MagicMock( + get_mediation_invite_record=async_mock.CoroutineMock( + return_value=unused_invite + ), + mark_default_invite_as_used=async_mock.CoroutineMock(return_value=used_invite), + ) + + class TestConductorMediationSetup(AsyncTestCase, Config): """ Test related with setting up mediation from given arguments or stored invitation. """ - def __get_invite_store_mock( - self, invite_string: str, invite_already_used: bool = False - ) -> async_mock.MagicMock: - unused_invite = MediationInviteRecord(invite_string, invite_already_used) - used_invite = MediationInviteRecord(invite_string, used=True) - - return async_mock.MagicMock( - retrieve_and_update_mediation_record=async_mock.CoroutineMock( - return_value=unused_invite - ), - mark_default_invite_as_used=async_mock.CoroutineMock( - return_value=used_invite - ), - ) - def __get_mediator_config( self, invite_string: str, connections_invite: bool = False ) -> ContextBuilder: @@ -895,7 +895,13 @@ def __get_mediator_config( return builder - async def test_mediator_invitation_0160(self): + @asynctest.patch.object( + test_module, + "MediationInviteStore", + return_value=get_invite_store_mock("test-invite"), + ) + @asynctest.patch.object(test_module.ConnectionInvitation, "from_url") + async def test_mediator_invitation_0160(self, mock_from_url, _): conductor = test_module.Conductor( self.__get_mediator_config("test-invite", True) ) @@ -904,8 +910,6 @@ async def test_mediator_invitation_0160(self): mock_conn_record = async_mock.MagicMock() with async_mock.patch.object( - test_module.ConnectionInvitation, "from_url" - ) as mock_from_url, async_mock.patch.object( test_module, "ConnectionManager", async_mock.MagicMock( @@ -917,21 +921,19 @@ async def test_mediator_invitation_0160(self): ), ) as mock_mgr, async_mock.patch.object( mock_conn_record, "metadata_set", async_mock.CoroutineMock() - ), async_mock.patch.object( - test_module, - "LOGGER", - async_mock.MagicMock( - exception=async_mock.MagicMock( - side_effect=Exception("This method should not have been called") - ) - ), ): await conductor.start() await conductor.stop() mock_from_url.assert_called_once_with("test-invite") mock_mgr.return_value.receive_invitation.assert_called_once() - async def test_mediator_invitation_0434(self): + @asynctest.patch.object( + test_module, + "MediationInviteStore", + return_value=get_invite_store_mock("test-invite"), + ) + @asynctest.patch.object(test_module.InvitationMessage, "from_url") + async def test_mediator_invitation_0434(self, mock_from_url, _): conductor = test_module.Conductor( self.__get_mediator_config("test-invite", False) ) @@ -946,8 +948,6 @@ async def test_mediator_invitation_0434(self): conn_record.accept = ConnRecord.ACCEPT_MANUAL await conn_record.save(await conductor.root_profile.session()) with async_mock.patch.object( - test_module.InvitationMessage, "from_url" - ) as mock_from_url, async_mock.patch.object( test_module, "OutOfBandManager", async_mock.MagicMock( @@ -957,21 +957,17 @@ async def test_mediator_invitation_0434(self): ) ) ), - ) as mock_mgr, async_mock.patch.object( - test_module, - "LOGGER", - async_mock.MagicMock( - exception=async_mock.MagicMock( - side_effect=Exception("This method should not have been called") - ) - ), - ): + ) as mock_mgr: await conductor.start() await conductor.stop() mock_from_url.assert_called_once_with("test-invite") mock_mgr.return_value.receive_invitation.assert_called_once() - async def test_mediation_invitation_should_use_stored_invitation(self): + @asynctest.patch.object(test_module, "MediationInviteStore") + @asynctest.patch.object(test_module.ConnectionInvitation, "from_url") + async def test_mediation_invitation_should_use_stored_invitation( + self, patched_from_url, patched_invite_store + ): """ Conductor should store the mediation invite if it differs from the stored one or if the stored one was not used yet. @@ -986,8 +982,9 @@ async def test_mediation_invitation_should_use_stored_invitation(self): ) await conductor.setup() mock_conn_record = async_mock.MagicMock() + mocked_store = get_invite_store_mock(invite_string) + patched_invite_store.return_value = mocked_store - invite_store_mock = self.__get_invite_store_mock(invite_string) connection_manager_mock = async_mock.MagicMock( receive_invitation=async_mock.CoroutineMock(return_value=mock_conn_record) ) @@ -997,10 +994,6 @@ async def test_mediation_invitation_should_use_stored_invitation(self): # when with async_mock.patch.object( - test_module, "MediationInviteStore", return_value=invite_store_mock - ), async_mock.patch.object( - test_module.ConnectionInvitation, "from_url" - ) as mock_connection_from_url, async_mock.patch.object( test_module, "ConnectionManager", return_value=connection_manager_mock ), async_mock.patch.object( mock_conn_record, "metadata_set", async_mock.CoroutineMock() @@ -1011,15 +1004,16 @@ async def test_mediation_invitation_should_use_stored_invitation(self): await conductor.stop() # then - invite_store_mock.retrieve_and_update_mediation_record.assert_called_with( - invite_string - ) + mocked_store.get_mediation_invite_record.assert_called_with(invite_string) + connection_manager_mock.receive_invitation.assert_called_once() - mock_connection_from_url.assert_called_with(invite_string) + patched_from_url.assert_called_with(invite_string) mock_mediation_manager.clear_default_mediator.assert_called_once() + @asynctest.patch.object(test_module, "MediationInviteStore") + @asynctest.patch.object(test_module, "ConnectionManager") async def test_mediation_invitation_should_not_create_connection_for_old_invitation( - self, + self, patched_connection_manager, patched_invite_store ): # given invite_string = "test-invite" @@ -1029,27 +1023,28 @@ async def test_mediation_invitation_should_not_create_connection_for_old_invitat ) await conductor.setup() - invite_store_mock = self.__get_invite_store_mock(invite_string, True) + invite_store_mock = get_invite_store_mock(invite_string, True) + patched_invite_store.return_value = invite_store_mock + connection_manager_mock = async_mock.MagicMock( receive_invitation=async_mock.CoroutineMock() ) + patched_connection_manager.return_value = connection_manager_mock # when - with async_mock.patch.object( - test_module, "MediationInviteStore", return_value=invite_store_mock - ), async_mock.patch.object( - test_module, "ConnectionManager", return_value=connection_manager_mock - ): - await conductor.start() - await conductor.stop() - - # then - invite_store_mock.retrieve_and_update_mediation_record.assert_called_with( - invite_string - ) - connection_manager_mock.receive_invitation.assert_not_called() - - async def test_mediator_invitation_x(self): + await conductor.start() + await conductor.stop() + + # then + invite_store_mock.get_mediation_invite_record.assert_called_with(invite_string) + connection_manager_mock.receive_invitation.assert_not_called() + + @asynctest.patch.object( + test_module, + "MediationInviteStore", + return_value=get_invite_store_mock("test-invite"), + ) + async def test_mediator_invitation_x(self, _): conductor = test_module.Conductor( self.__get_mediator_config("test-invite", True) ) diff --git a/aries_cloudagent/protocols/coordinate_mediation/mediation_invite_store.py b/aries_cloudagent/protocols/coordinate_mediation/mediation_invite_store.py index 7b2adf16fd..d40c58602b 100644 --- a/aries_cloudagent/protocols/coordinate_mediation/mediation_invite_store.py +++ b/aries_cloudagent/protocols/coordinate_mediation/mediation_invite_store.py @@ -101,13 +101,13 @@ async def __retrieve(self) -> Optional[MediationInviteRecord]: else None ) - async def retrieve_and_update_mediation_record( + async def __update_mediation_record( self, provided_mediation_invitation: str ) -> MediationInviteRecord: """ - Retrieve stored mediation invite and optionally updates it. + Update the stored invitation when a new invitation is provided. - Stored value is updated if `provided_mediation_invitation` has changed. + Stored value is only updated if `provided_mediation_invitation` has changed. Updated record is marked as unused. :param provided_mediation_invitation: mediation invite provided by user @@ -139,3 +139,32 @@ async def mark_default_invite_as_used(self): await self.store(updated_record) return updated_record + + async def get_mediation_invite_record( + self, provided_mediation_invitation: Optional[str] + ) -> Optional[MediationInviteRecord]: + """ + Provide the MediationInviteRecord to use/that was used for mediation. + + Returned record may have been used already. + + Stored record is updated if `provided_mediation_invitation` has changed. + Updated record is marked as unused. + + :param provided_mediation_invitation: mediation invite provided by user + :return: mediation invite to use/that was used to connect to the mediator. None if + no invitation was provided/provisioned. + """ + + stored_invite = await self.__retrieve() + + if stored_invite is None and provided_mediation_invitation is None: + return None + elif stored_invite is None and provided_mediation_invitation is not None: + return await self.store( + MediationInviteRecord.unused(provided_mediation_invitation) + ) + elif stored_invite is not None and provided_mediation_invitation is None: + return stored_invite + elif stored_invite is not None and provided_mediation_invitation is not None: + return await self.__update_mediation_record(provided_mediation_invitation) diff --git a/aries_cloudagent/protocols/coordinate_mediation/v1_0/tests/test_mediation_invite_store.py b/aries_cloudagent/protocols/coordinate_mediation/v1_0/tests/test_mediation_invite_store.py index 123a53a22f..6cf8857a9f 100644 --- a/aries_cloudagent/protocols/coordinate_mediation/v1_0/tests/test_mediation_invite_store.py +++ b/aries_cloudagent/protocols/coordinate_mediation/v1_0/tests/test_mediation_invite_store.py @@ -52,7 +52,7 @@ def setUp(self): self.storage = async_mock.MagicMock(spec=BaseStorage) self.mediation_invite_store = MediationInviteStore(self.storage) - async def test_retrieve_update_should_create_record_to_store_mediation_invite_when_no_record_exists( + async def test_store_create_record_to_store_mediation_invite_when_no_record_exists( self, ): # given @@ -62,10 +62,8 @@ async def test_retrieve_update_should_create_record_to_store_mediation_invite_wh expected_updated_record = MediationInviteRecord.unused(mediation_invite_url) # when - stored_invite = ( - await self.mediation_invite_store.retrieve_and_update_mediation_record( - mediation_invite_url - ) + stored_invite = await self.mediation_invite_store.store( + MediationInviteRecord.unused(mediation_invite_url) ) # then @@ -74,7 +72,7 @@ async def test_retrieve_update_should_create_record_to_store_mediation_invite_wh ) assert stored_invite == expected_updated_record - async def test_retrieve_update_should_update_record_when_a_mediation_invite_record_exists( + async def test_store_should_update_record_when_a_mediation_invite_record_exists( self, ): # given @@ -85,10 +83,8 @@ async def test_retrieve_update_should_update_record_when_a_mediation_invite_reco expected_updated_record = MediationInviteRecord.unused(mediation_invite_url) # when - stored_invite = ( - await self.mediation_invite_store.retrieve_and_update_mediation_record( - mediation_invite_url - ) + stored_invite = await self.mediation_invite_store.store( + MediationInviteRecord.unused(mediation_invite_url) ) # then @@ -118,3 +114,48 @@ async def test_mark_default_invite_as_used_should_raise_when_no_invite(self): # when - then with self.assertRaises(NoDefaultMediationInviteException): await self.mediation_invite_store.mark_default_invite_as_used() + + async def test_get_mediation_invite_record_returns_none_when_no_invite_available( + self, + ): + # given + self.storage.get_record.side_effect = StorageNotFoundError + + # when + invite = await self.mediation_invite_store.get_mediation_invite_record(None) + + # then + assert invite is None + + async def test_get_mediation_invite_returns_stored_record_when_no_invite_provided( + self, + ): + # given + stored_record = _storage_record_for( + "somepla.ce:4242/alongandunreadablebase64payload" + ) + expected_invite = MediationInviteRecord.from_json(stored_record.value) + self.storage.get_record.return_value = stored_record + + # when + invite = await self.mediation_invite_store.get_mediation_invite_record(None) + + # then + assert invite == expected_invite + + async def test_get_mediation_invite_stores_and_returns_provided_invite_if_none_stored( + self, + ): + # given + expected_invite = MediationInviteRecord.unused( + "somepla.ce:4242/alongandunreadablebase64payload" + ) + self.storage.get_record.return_value = None + + # when + invite = await self.mediation_invite_store.get_mediation_invite_record( + expected_invite.invite + ) + + # then + assert invite == expected_invite