From f12df9d504274816eac59200955e4a014050243d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Humbert?= Date: Wed, 1 Sep 2021 16:50:45 +0200 Subject: [PATCH] Refactoring step before introducing changes in mediation connection setup. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Align `OutOfBandManager.receive_invitation` signature with the `ConnectionManager.receive_invitation` and `DIDXManager.receive_invitation` * rename first parameter to `invitation` * change return type to `ConnRecord` and leave serialization work for users of the manager * Serialize the received `ConnRecord` in `protocols.out_of_band.v1_0.routes` * Adapt related tests. `assert ConnRecord.deserialize(conn_rec)` is dubious as a test assertion but boiled down to `assert conn_rec is not None`. * Use this refactoring to reduce branching in mediator connection handling in `conductor.py`. Signed-off-by: Clément Humbert --- aries_cloudagent/core/conductor.py | 62 ++++++++--------- aries_cloudagent/core/tests/test_conductor.py | 3 +- .../protocols/out_of_band/v1_0/manager.py | 42 ++++++------ .../protocols/out_of_band/v1_0/routes.py | 2 +- .../out_of_band/v1_0/tests/test_manager.py | 68 +++++++------------ .../out_of_band/v1_0/tests/test_routes.py | 14 ++-- 6 files changed, 84 insertions(+), 107 deletions(-) diff --git a/aries_cloudagent/core/conductor.py b/aries_cloudagent/core/conductor.py index 4b06ed3db9..d57b0ad6a4 100644 --- a/aries_cloudagent/core/conductor.py +++ b/aries_cloudagent/core/conductor.py @@ -19,7 +19,6 @@ from ..config.ledger import get_genesis_transactions, ledger_config from ..config.logging import LoggingConfigurator from ..config.wallet import wallet_config -from ..connections.models.conn_record import ConnRecord from ..core.profile import Profile from ..ledger.error import LedgerConfigError, LedgerTransactionError from ..messaging.responder import BaseResponder @@ -303,45 +302,38 @@ async def start(self) -> None: LOGGER.exception("Error creating invitation") # Accept mediation invitation if specified - mediation_invitation = context.settings.get("mediation.invite") + mediation_invitation: str = context.settings.get("mediation.invite") if mediation_invitation: try: mediation_connections_invite = context.settings.get( "mediation.connections_invite", False ) - if mediation_connections_invite: - async with self.root_profile.session() as session: - mgr = ConnectionManager(session) - conn_record = await mgr.receive_invitation( - invitation=ConnectionInvitation.from_url( - mediation_invitation - ), - auto_accept=True, - ) - await conn_record.metadata_set( - session, MediationManager.SEND_REQ_AFTER_CONNECTION, True - ) - await conn_record.metadata_set( - session, MediationManager.SET_TO_DEFAULT_ON_GRANTED, True - ) - print("Attempting to connect to mediator...") - del mgr - else: - async with self.root_profile.session() as session: - mgr = OutOfBandManager(session) - conn_record_dict = await mgr.receive_invitation( - invi_msg=InvitationMessage.from_url(mediation_invitation), - auto_accept=True, - ) - conn_record = ConnRecord.deserialize(conn_record_dict) - await conn_record.metadata_set( - session, MediationManager.SEND_REQ_AFTER_CONNECTION, True - ) - await conn_record.metadata_set( - session, MediationManager.SET_TO_DEFAULT_ON_GRANTED, True - ) - print("Attempting to connect to mediator...") - del mgr + invitation_handler = ( + ConnectionInvitation + if mediation_connections_invite + else InvitationMessage + ) + + async with self.root_profile.session() as session: + mgr = ( + ConnectionManager(session) + if mediation_connections_invite + else OutOfBandManager(session) + ) + + conn_record = await mgr.receive_invitation( + invitation=invitation_handler.from_url(mediation_invitation), + auto_accept=True, + ) + + await conn_record.metadata_set( + session, MediationManager.SEND_REQ_AFTER_CONNECTION, True + ) + await conn_record.metadata_set( + session, MediationManager.SET_TO_DEFAULT_ON_GRANTED, True + ) + print("Attempting to connect to mediator...") + del mgr except Exception: LOGGER.exception("Error accepting mediation invitation") diff --git a/aries_cloudagent/core/tests/test_conductor.py b/aries_cloudagent/core/tests/test_conductor.py index 8e97afbb1c..705d2ebc74 100644 --- a/aries_cloudagent/core/tests/test_conductor.py +++ b/aries_cloudagent/core/tests/test_conductor.py @@ -775,7 +775,6 @@ async def test_mediator_invitation_0434(self): ) conn_record.accept = ConnRecord.ACCEPT_MANUAL await conn_record.save(await conductor.root_profile.session()) - conn_record_dict = conn_record.serialize() with async_mock.patch.object( test_module.InvitationMessage, "from_url" ) as mock_from_url, async_mock.patch.object( @@ -784,7 +783,7 @@ async def test_mediator_invitation_0434(self): async_mock.MagicMock( return_value=async_mock.MagicMock( receive_invitation=async_mock.CoroutineMock( - return_value=conn_record_dict + return_value=conn_record ) ) ), diff --git a/aries_cloudagent/protocols/out_of_band/v1_0/manager.py b/aries_cloudagent/protocols/out_of_band/v1_0/manager.py index 374f1bcba2..4b846c2192 100644 --- a/aries_cloudagent/protocols/out_of_band/v1_0/manager.py +++ b/aries_cloudagent/protocols/out_of_band/v1_0/manager.py @@ -363,17 +363,17 @@ async def create_invitation( async def receive_invitation( self, - invi_msg: InvitationMessage, + invitation: InvitationMessage, use_existing_connection: bool = True, auto_accept: bool = None, alias: str = None, mediation_id: str = None, - ) -> dict: + ) -> ConnRecord: """ Receive an out of band invitation message. Args: - invi_msg: invitation message + invitation: invitation message use_existing_connection: whether to use existing connection if possible auto_accept: whether to accept the invitation automatically alias: Alias for connection record @@ -390,15 +390,15 @@ async def receive_invitation( mediation_id = None # There must be exactly 1 service entry - if len(invi_msg.services) != 1: + if len(invitation.services) != 1: raise OutOfBandManagerError("service array must have exactly one element") - if not (invi_msg.requests_attach or invi_msg.handshake_protocols): + if not (invitation.requests_attach or invitation.handshake_protocols): raise OutOfBandManagerError( "Invitation must specify handshake_protocols, requests_attach, or both" ) # Get the single service item - oob_service_item = invi_msg.services[0] + oob_service_item = invitation.services[0] if isinstance(oob_service_item, ServiceMessage): service = oob_service_item public_did = None @@ -437,7 +437,7 @@ async def receive_invitation( for hsp in dict.fromkeys( [ DIDCommPrefix.unqualify(proto) - for proto in invi_msg.handshake_protocols + for proto in invitation.handshake_protocols ] ) ] @@ -454,7 +454,7 @@ async def receive_invitation( ) if conn_rec is not None: num_included_protocols = len(unq_handshake_protos) - num_included_req_attachments = len(invi_msg.requests_attach) + num_included_req_attachments = len(invitation.requests_attach) # With handshake protocol, request attachment; use existing connection if ( num_included_protocols >= 1 @@ -462,7 +462,7 @@ async def receive_invitation( and use_existing_connection ): await self.create_handshake_reuse_message( - invi_msg=invi_msg, + invi_msg=invitation, conn_record=conn_rec, ) try: @@ -526,7 +526,7 @@ async def receive_invitation( if proto is HSProto.RFC23: didx_mgr = DIDXManager(self._session) conn_rec = await didx_mgr.receive_invitation( - invitation=invi_msg, + invitation=invitation, their_public_did=public_did, auto_accept=auto_accept, alias=alias, @@ -543,9 +543,9 @@ async def receive_invitation( ] or [] connection_invitation = ConnectionInvitation.deserialize( { - "@id": invi_msg._id, + "@id": invitation._id, "@type": DIDCommPrefix.qualify_current(proto.name), - "label": invi_msg.label, + "label": invitation.label, "recipientKeys": service.recipient_keys, "serviceEndpoint": service.service_endpoint, "routingKeys": service.routing_keys, @@ -563,8 +563,8 @@ async def receive_invitation( break # Request Attach - if len(invi_msg.requests_attach) >= 1 and conn_rec is not None: - req_attach = invi_msg.requests_attach[0] + if len(invitation.requests_attach) >= 1 and conn_rec is not None: + req_attach = invitation.requests_attach[0] if isinstance(req_attach, AttachDecorator): if req_attach.data is not None: unq_req_attach_type = DIDCommPrefix.unqualify( @@ -575,14 +575,14 @@ async def receive_invitation( req_attach=req_attach, service=service, conn_rec=conn_rec, - trace=(invi_msg._trace is not None), + trace=(invitation._trace is not None), ) elif unq_req_attach_type == PRES_20_REQUEST: await self._process_pres_request_v2( req_attach=req_attach, service=service, conn_rec=conn_rec, - trace=(invi_msg._trace is not None), + trace=(invitation._trace is not None), ) elif unq_req_attach_type == CREDENTIAL_OFFER: if auto_accept or self._session.settings.get( @@ -597,12 +597,12 @@ async def receive_invitation( LOGGER.warning( "Connection not ready to receive credential, " f"For connection_id:{conn_rec.connection_id} and " - f"invitation_msg_id {invi_msg._id}", + f"invitation_msg_id {invitation._id}", ) await self._process_cred_offer_v1( req_attach=req_attach, conn_rec=conn_rec, - trace=(invi_msg._trace is not None), + trace=(invitation._trace is not None), ) elif unq_req_attach_type == CRED_20_OFFER: if auto_accept or self._session.settings.get( @@ -617,12 +617,12 @@ async def receive_invitation( LOGGER.warning( "Connection not ready to receive credential, " f"For connection_id:{conn_rec.connection_id} and " - f"invitation_msg_id {invi_msg._id}", + f"invitation_msg_id {invitation._id}", ) await self._process_cred_offer_v2( req_attach=req_attach, conn_rec=conn_rec, - trace=(invi_msg._trace is not None), + trace=(invitation._trace is not None), ) else: raise OutOfBandManagerError( @@ -636,7 +636,7 @@ async def receive_invitation( else: raise OutOfBandManagerError("requests~attach is not properly formatted") - return conn_rec.serialize() + return conn_rec async def _process_pres_request_v1( self, diff --git a/aries_cloudagent/protocols/out_of_band/v1_0/routes.py b/aries_cloudagent/protocols/out_of_band/v1_0/routes.py index 8903f458a1..388da7500d 100644 --- a/aries_cloudagent/protocols/out_of_band/v1_0/routes.py +++ b/aries_cloudagent/protocols/out_of_band/v1_0/routes.py @@ -227,7 +227,7 @@ async def invitation_receive(request: web.BaseRequest): except (DIDXManagerError, StorageError, BaseModelError) as err: raise web.HTTPBadRequest(reason=err.roll_up) from err - return web.json_response(result) + return web.json_response(result.serialize()) async def register(app: web.Application): diff --git a/aries_cloudagent/protocols/out_of_band/v1_0/tests/test_manager.py b/aries_cloudagent/protocols/out_of_band/v1_0/tests/test_manager.py index 6173705576..8c17a2fa03 100644 --- a/aries_cloudagent/protocols/out_of_band/v1_0/tests/test_manager.py +++ b/aries_cloudagent/protocols/out_of_band/v1_0/tests/test_manager.py @@ -852,7 +852,7 @@ async def test_dif_req_v2_attach_pres_existing_conn_auto_present_pres_msg_with_c conn_rec = await self.manager.receive_invitation( mock_oob_invi, use_existing_connection=True ) - assert ConnRecord.deserialize(conn_rec) + assert conn_rec is not None async def test_dif_req_v2_attach_pres_existing_conn_auto_present_pres_msg_with_nonce( self, @@ -1018,7 +1018,7 @@ async def test_dif_req_v2_attach_pres_existing_conn_auto_present_pres_msg_with_n conn_rec = await self.manager.receive_invitation( mock_oob_invi, use_existing_connection=True ) - assert ConnRecord.deserialize(conn_rec) + assert conn_rec is not None async def test_create_invitation_public_x_no_public_invites(self): self.session.context.update_settings({"public_invites": False}) @@ -1185,7 +1185,7 @@ async def test_receive_invitation_with_valid_mediation(self): ) invi_msg = invite.invitation invitee_record = await self.manager.receive_invitation( - invi_msg=invi_msg, + invitation=invi_msg, mediation_id=mediation_record._id, ) mock_didx_recv_invi.assert_called_once_with( @@ -1300,14 +1300,14 @@ async def test_receive_invitation_connection(self): ) result = await self.manager.receive_invitation( - invi_msg=oob_invi_rec.invitation, + invitation=oob_invi_rec.invitation, use_existing_connection=True, auto_accept=True, ) - connection_id = UUID(result.get("connection_id"), version=4) + connection_id = UUID(result.connection_id, version=4) assert ( - connection_id.hex == result.get("connection_id").replace("-", "") - and len(result.get("connection_id")) > 5 + connection_id.hex == result.connection_id.replace("-", "") + and len(result.connection_id) > 5 ) async def test_receive_invitation_services_with_neither_service_blocks_nor_dids( @@ -1397,10 +1397,10 @@ async def test_receive_invitation_req_pres_v1_0_attachment_x(self): with self.assertRaises(OutOfBandManagerError) as context: result = await self.manager.receive_invitation(mock_oob_invi) - connection_id = UUID(result.get("connection_id"), version=4) + connection_id = UUID(result.connection_id, version=4) assert ( - connection_id.hex == result.get("connection_id") - and len(result.get("connection_id")) > 5 + connection_id.hex == result.connection_id + and len(result.connection_id) > 5 ) assert "requests~attach is not properly formatted" in str(context.exception) @@ -2016,9 +2016,7 @@ async def test_existing_conn_record_public_did(self): ) is None ) - assert ( - result.get("connection_id") == retrieved_conn_records[0].connection_id - ) + assert result.connection_id == retrieved_conn_records[0].connection_id async def test_existing_conn_record_public_did_not_accepted(self): self.session.context.update_settings({"public_invites": True}) @@ -2117,9 +2115,7 @@ async def test_existing_conn_record_public_did_not_accepted(self): ) == "not_accepted" ) - assert ( - result.get("connection_id") != retrieved_conn_records[0].connection_id - ) + assert result.connection_id != retrieved_conn_records[0].connection_id async def test_existing_conn_record_public_did_inverse_cases(self): self.session.context.update_settings({"public_invites": True}) @@ -2134,11 +2130,6 @@ async def test_existing_conn_record_public_did_inverse_cases(self): await test_exist_conn.save(self.session) await test_exist_conn.metadata_set(self.session, "reuse_msg_state", "initial") await test_exist_conn.metadata_set(self.session, "reuse_msg_id", "test_123") - receipt = MessageReceipt( - recipient_did=TestConfig.test_did, - recipient_did_public=False, - sender_did=TestConfig.test_target_did, - ) with async_mock.patch.object( DIDXManager, "receive_invitation", autospec=True @@ -2180,9 +2171,7 @@ async def test_existing_conn_record_public_did_inverse_cases(self): }, alt=True, ) - assert ( - result.get("connection_id") != retrieved_conn_records[0].connection_id - ) + assert result.connection_id != retrieved_conn_records[0].connection_id async def test_existing_conn_record_public_did_timeout(self): self.session.context.update_settings({"public_invites": True}) @@ -2475,7 +2464,7 @@ async def test_req_v1_attach_presentation_existing_conn_auto_present_pres_msg(se conn_rec = await self.manager.receive_invitation( mock_oob_invi, use_existing_connection=True ) - assert ConnRecord.deserialize(conn_rec) + assert conn_rec is not None async def test_req_v1_attach_pres_catch_value_error(self): self.session.context.update_settings({"public_invites": True}) @@ -2780,7 +2769,7 @@ async def test_req_v2_attach_presentation_existing_conn_auto_present_pres_msg(se conn_rec = await self.manager.receive_invitation( mock_oob_invi, use_existing_connection=True ) - assert ConnRecord.deserialize(conn_rec) + assert conn_rec is not None async def test_req_v2_attach_pres_catch_value_error(self): self.session.context.update_settings({"public_invites": True}) @@ -2989,7 +2978,7 @@ async def test_req_attach_cred_offer_v1(self): conn_rec = await self.manager.receive_invitation( mock_oob_invi, use_existing_connection=True ) - assert ConnRecord.deserialize(conn_rec) + assert conn_rec is not None async def test_req_attach_cred_offer_v1_no_issue(self): self.session.context.update_settings({"public_invites": True}) @@ -3172,7 +3161,7 @@ async def test_req_attach_cred_offer_v2(self): conn_rec = await self.manager.receive_invitation( mock_oob_invi, use_existing_connection=True ) - assert ConnRecord.deserialize(conn_rec) + assert conn_rec is not None async def test_req_attach_cred_offer_v2_no_issue(self): self.session.context.update_settings({"public_invites": True}) @@ -3382,11 +3371,6 @@ async def test_request_attach_cred_offer_v1_check_conn_rec_active_timeout(self): await test_exist_conn.metadata_set(self.session, "reuse_msg_state", "initial") await test_exist_conn.metadata_set(self.session, "reuse_msg_id", "test_123") - receipt = MessageReceipt( - recipient_did=TestConfig.test_did, - recipient_did_public=False, - sender_did=TestConfig.test_target_did, - ) req_attach = deepcopy(TestConfig.req_attach_v1) del req_attach["data"]["json"] req_attach["data"]["json"] = TestConfig.CRED_OFFER_V1.serialize() @@ -3396,7 +3380,7 @@ async def test_request_attach_cred_offer_v1_check_conn_rec_active_timeout(self): DIDXManager, "receive_invitation", autospec=True, - ) as didx_mgr_receive_invitation, async_mock.patch.object( + ), async_mock.patch.object( V10CredManager, "receive_offer", autospec=True, @@ -3407,7 +3391,7 @@ async def test_request_attach_cred_offer_v1_check_conn_rec_active_timeout(self): OutOfBandManager, "fetch_connection_targets", autospec=True, - ) as oob_mgr_fetch_conn, async_mock.patch.object( + ), async_mock.patch.object( OutOfBandManager, "find_existing_connection", autospec=True, @@ -3415,7 +3399,7 @@ async def test_request_attach_cred_offer_v1_check_conn_rec_active_timeout(self): OutOfBandManager, "check_reuse_msg_state", autospec=True, - ) as oob_mgr_check_reuse_state, async_mock.patch.object( + ), async_mock.patch.object( OutOfBandManager, "conn_rec_is_active", autospec=True, @@ -3423,19 +3407,19 @@ async def test_request_attach_cred_offer_v1_check_conn_rec_active_timeout(self): OutOfBandManager, "create_handshake_reuse_message", autospec=True, - ) as oob_mgr_create_reuse_msg, async_mock.patch.object( + ), async_mock.patch.object( OutOfBandManager, "receive_reuse_message", autospec=True, - ) as oob_mgr_receive_reuse_msg, async_mock.patch.object( + ), async_mock.patch.object( OutOfBandManager, "receive_reuse_accepted_message", autospec=True, - ) as oob_mgr_receive_accept_msg, async_mock.patch.object( + ), async_mock.patch.object( OutOfBandManager, "receive_problem_report", autospec=True, - ) as oob_mgr_receive_problem_report, async_mock.patch.object( + ), async_mock.patch.object( V10CredManager, "create_request", autospec=True, @@ -3458,7 +3442,7 @@ async def test_request_attach_cred_offer_v1_check_conn_rec_active_timeout(self): mock_oob_invi, use_existing_connection=True ) mock_logger_warning.assert_called_once() - assert ConnRecord.deserialize(conn_rec) + assert conn_rec is not None async def test_request_attach_cred_offer_v2_check_conn_rec_active_timeout(self): self.session.context.update_settings({"public_invites": True}) @@ -3552,4 +3536,4 @@ async def test_request_attach_cred_offer_v2_check_conn_rec_active_timeout(self): mock_oob_invi, use_existing_connection=True ) mock_logger_warning.assert_called_once() - assert ConnRecord.deserialize(conn_rec) + assert conn_rec is not None diff --git a/aries_cloudagent/protocols/out_of_band/v1_0/tests/test_routes.py b/aries_cloudagent/protocols/out_of_band/v1_0/tests/test_routes.py index cf7a29795e..d7aaffabf8 100644 --- a/aries_cloudagent/protocols/out_of_band/v1_0/tests/test_routes.py +++ b/aries_cloudagent/protocols/out_of_band/v1_0/tests/test_routes.py @@ -1,9 +1,8 @@ -import json - from asynctest import TestCase as AsyncTestCase from asynctest import mock as async_mock from .....admin.request_context import AdminRequestContext +from .....connections.models.conn_record import ConnRecord from .. import routes as test_module @@ -86,20 +85,23 @@ async def test_invitation_create_x(self): async def test_invitation_receive(self): self.request.json = async_mock.CoroutineMock() + expected_connection_record = ConnRecord(connection_id="some-id") with async_mock.patch.object( test_module, "OutOfBandManager", autospec=True ) as mock_oob_mgr, async_mock.patch.object( test_module.InvitationMessage, "deserialize", async_mock.Mock() - ) as mock_invi_deser, async_mock.patch.object( + ), async_mock.patch.object( test_module.web, "json_response", async_mock.Mock() ) as mock_json_response: mock_oob_mgr.return_value.receive_invitation = async_mock.CoroutineMock( - return_value={"abc": "123"} + return_value=expected_connection_record ) - result = await test_module.invitation_receive(self.request) - mock_json_response.assert_called_once_with({"abc": "123"}) + await test_module.invitation_receive(self.request) + mock_json_response.assert_called_once_with( + expected_connection_record.serialize() + ) async def test_invitation_receive_forbidden_x(self): self.context.update_settings({"admin.no_receive_invites": True})