diff --git a/aries_cloudagent/protocols/connections/v1_0/handlers/connection_request_handler.py b/aries_cloudagent/protocols/connections/v1_0/handlers/connection_request_handler.py index a4c5abbbb6..b1f26a85ae 100644 --- a/aries_cloudagent/protocols/connections/v1_0/handlers/connection_request_handler.py +++ b/aries_cloudagent/protocols/connections/v1_0/handlers/connection_request_handler.py @@ -5,6 +5,7 @@ BaseResponder, RequestContext, ) +from .....connections.models.conn_record import ConnRecord from ..manager import ConnectionManager, ConnectionManagerError from ..messages.connection_request import ConnectionRequest @@ -38,11 +39,19 @@ async def handle(self, context: RequestContext, responder: BaseResponder): mediation_metadata = {} try: - await mgr.receive_request( + connection = await mgr.receive_request( context.message, context.message_receipt, mediation_id=mediation_metadata.get("id"), ) + + if connection.accept == ConnRecord.ACCEPT_AUTO: + response = await mgr.create_response(connection) + await responder.send_reply( + response, connection_id=connection.connection_id + ) + else: + self._logger.debug("Connection request will await acceptance") except ConnectionManagerError as e: self._logger.exception("Error receiving connection request") if e.error_code: diff --git a/aries_cloudagent/protocols/connections/v1_0/handlers/tests/test_request_handler.py b/aries_cloudagent/protocols/connections/v1_0/handlers/tests/test_request_handler.py index b133e7dfdb..ec16a72cef 100644 --- a/aries_cloudagent/protocols/connections/v1_0/handlers/tests/test_request_handler.py +++ b/aries_cloudagent/protocols/connections/v1_0/handlers/tests/test_request_handler.py @@ -87,6 +87,24 @@ async def test_called(self, mock_conn_mgr, request_context): ) assert not responder.messages + @pytest.mark.asyncio + @async_mock.patch.object(handler, "ConnectionManager") + async def test_called_with_auto_response(self, mock_conn_mgr, request_context): + mock_conn_rec = async_mock.MagicMock() + mock_conn_rec.accept = ConnRecord.ACCEPT_AUTO + mock_conn_mgr.return_value.receive_request = async_mock.CoroutineMock( + return_value=mock_conn_rec + ) + mock_conn_mgr.return_value.create_response = async_mock.CoroutineMock() + request_context.message = ConnectionRequest() + handler_inst = handler.ConnectionRequestHandler() + responder = MockResponder() + await handler_inst.handle(request_context, responder) + mock_conn_mgr.return_value.receive_request.assert_called_once_with( + request_context.message, request_context.message_receipt, mediation_id=None + ) + assert responder.messages + @pytest.mark.asyncio @async_mock.patch.object(handler, "ConnectionManager") async def test_connection_record_with_mediation_metadata( diff --git a/aries_cloudagent/protocols/connections/v1_0/manager.py b/aries_cloudagent/protocols/connections/v1_0/manager.py index a2441132ad..5841945bb6 100644 --- a/aries_cloudagent/protocols/connections/v1_0/manager.py +++ b/aries_cloudagent/protocols/connections/v1_0/manager.py @@ -641,21 +641,6 @@ async def receive_request( keylist_updates, connection_id=mediation_record.connection_id ) - if connection.accept == ConnRecord.ACCEPT_AUTO: - response = await self.create_response(connection, mediation_id=mediation_id) - responder = self.profile.inject_or(BaseResponder) - if responder: - await responder.send_reply( - response, connection_id=connection.connection_id - ) - async with self.profile.session() as session: - # refetch connection for accurate state - connection = await ConnRecord.retrieve_by_id( - session, connection.connection_id - ) - else: - self._logger.debug("Connection request will await acceptance") - return connection async def create_response( diff --git a/aries_cloudagent/protocols/connections/v1_0/tests/test_manager.py b/aries_cloudagent/protocols/connections/v1_0/tests/test_manager.py index 33743caa1f..e3b88a1dd5 100644 --- a/aries_cloudagent/protocols/connections/v1_0/tests/test_manager.py +++ b/aries_cloudagent/protocols/connections/v1_0/tests/test_manager.py @@ -713,12 +713,6 @@ async def test_receive_request_public_did_oob_invite(self): conn_rec = await self.manager.receive_request(mock_request, receipt) assert conn_rec - messages = self.responder.messages - assert len(messages) == 1 - (result, target) = messages[0] - assert type(result) == ConnectionResponse - assert "connection_id" in target - async def test_receive_request_public_did_conn_invite(self): async with self.profile.session() as session: mock_request = async_mock.MagicMock() @@ -755,12 +749,6 @@ async def test_receive_request_public_did_conn_invite(self): conn_rec = await self.manager.receive_request(mock_request, receipt) assert conn_rec - messages = self.responder.messages - assert len(messages) == 1 - (result, target) = messages[0] - assert type(result) == ConnectionResponse - assert "connection_id" in target - async def test_receive_request_multi_use_multitenant(self): async with self.profile.session() as session: multiuse_info = await session.wallet.create_local_did( diff --git a/aries_cloudagent/protocols/didexchange/v1_0/handlers/request_handler.py b/aries_cloudagent/protocols/didexchange/v1_0/handlers/request_handler.py index 1da1a90261..b7f036e11f 100644 --- a/aries_cloudagent/protocols/didexchange/v1_0/handlers/request_handler.py +++ b/aries_cloudagent/protocols/didexchange/v1_0/handlers/request_handler.py @@ -1,5 +1,6 @@ """Connection request handler under RFC 23 (DID exchange).""" +from .....connections.models.conn_record import ConnRecord from .....messaging.base_handler import BaseHandler, BaseResponder, RequestContext from ....problem_report.v1_0.message import ProblemReport @@ -33,8 +34,10 @@ async def handle(self, context: RequestContext, responder: BaseResponder): ) else: mediation_metadata = {} + + mediation_id = mediation_metadata.get("id") try: - await mgr.receive_request( + conn_rec = await mgr.receive_request( request=context.message, recipient_did=context.message_receipt.recipient_did, recipient_verkey=( @@ -42,8 +45,24 @@ async def handle(self, context: RequestContext, responder: BaseResponder): if context.message_receipt.recipient_did_public else context.message_receipt.recipient_verkey ), - mediation_id=mediation_metadata.get("id"), + mediation_id=mediation_id, ) + + # Auto respond + if conn_rec.accept == ConnRecord.ACCEPT_AUTO: + response = await mgr.create_response( + conn_rec, + mediation_id=mediation_id, + ) + await responder.send_reply( + response, connection_id=conn_rec.connection_id + ) + conn_rec.state = ConnRecord.State.RESPONSE.rfc23 + async with context.session() as session: + await conn_rec.save(session, reason="Sent connection response") + else: + self._logger.debug("DID exchange request will await acceptance") + except DIDXManagerError as e: self._logger.exception("Error receiving RFC 23 connection request") if e.error_code: diff --git a/aries_cloudagent/protocols/didexchange/v1_0/handlers/tests/test_request_handler.py b/aries_cloudagent/protocols/didexchange/v1_0/handlers/tests/test_request_handler.py index 409f8387c9..21eb4c2688 100644 --- a/aries_cloudagent/protocols/didexchange/v1_0/handlers/tests/test_request_handler.py +++ b/aries_cloudagent/protocols/didexchange/v1_0/handlers/tests/test_request_handler.py @@ -160,6 +160,28 @@ async def test_called(self, mock_didx_mgr): ) assert not responder.messages + @async_mock.patch.object(test_module, "DIDXManager") + async def test_called_with_auto_response(self, mock_didx_mgr): + mock_conn_rec = async_mock.MagicMock() + mock_conn_rec.accept = conn_record.ConnRecord.ACCEPT_AUTO + mock_conn_rec.save = async_mock.CoroutineMock() + mock_didx_mgr.return_value.receive_request = async_mock.CoroutineMock( + return_value=mock_conn_rec + ) + mock_didx_mgr.return_value.create_response = async_mock.CoroutineMock() + self.ctx.message = DIDXRequest() + handler_inst = test_module.DIDXRequestHandler() + responder = MockResponder() + await handler_inst.handle(self.ctx, responder) + + mock_didx_mgr.return_value.receive_request.assert_called_once_with( + request=self.ctx.message, + recipient_did=self.ctx.message_receipt.recipient_did, + recipient_verkey=None, + mediation_id=None, + ) + assert responder.messages + @async_mock.patch.object(test_module, "DIDXManager") async def test_problem_report(self, mock_didx_mgr): mock_didx_mgr.return_value.receive_request = async_mock.CoroutineMock( diff --git a/aries_cloudagent/protocols/didexchange/v1_0/manager.py b/aries_cloudagent/protocols/didexchange/v1_0/manager.py index 0e123b3961..cd4f0d2079 100644 --- a/aries_cloudagent/protocols/didexchange/v1_0/manager.py +++ b/aries_cloudagent/protocols/didexchange/v1_0/manager.py @@ -549,28 +549,11 @@ async def receive_request( # Send keylist updates to mediator mediation_record = await mediation_record_if_id(self.profile, mediation_id) if keylist_updates and mediation_record: - responder = self.profile.inject_or(BaseResponder) + responder = self.profile.inject(BaseResponder) await responder.send( keylist_updates, connection_id=mediation_record.connection_id ) - if auto_accept: - response = await self.create_response( - conn_rec, - my_endpoint, - mediation_id=mediation_id, - ) - responder = self.profile.inject_or(BaseResponder) - if responder: - await responder.send_reply( - response, connection_id=conn_rec.connection_id - ) - conn_rec.state = ConnRecord.State.RESPONSE.rfc23 - async with self.profile.session() as session: - await conn_rec.save(session, reason="Sent connection response") - else: - self._logger.debug("DID exchange request will await acceptance") - return conn_rec async def create_response( diff --git a/aries_cloudagent/protocols/didexchange/v1_0/tests/test_manager.py b/aries_cloudagent/protocols/didexchange/v1_0/tests/test_manager.py index 9a8049b1d0..13c7f169f4 100644 --- a/aries_cloudagent/protocols/didexchange/v1_0/tests/test_manager.py +++ b/aries_cloudagent/protocols/didexchange/v1_0/tests/test_manager.py @@ -540,11 +540,6 @@ async def test_receive_request_explicit_public_did(self): ) assert conn_rec - messages = self.responder.messages - assert len(messages) == 2 - (result, target) = messages[0] - assert "connection_id" in target - async def test_receive_request_invi_not_found(self): async with self.profile.session() as session: mock_request = async_mock.MagicMock(