Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix connection record response for mobile #1469

Merged
merged 8 commits into from
Dec 6, 2021
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
BaseResponder,
RequestContext,
)
from .....connections.models.conn_record import ConnRecord

from ..manager import ConnectionManager, ConnectionManagerError
from ..messages.connection_request import ConnectionRequest
Expand Down Expand Up @@ -38,11 +39,19 @@ async def handle(self, context: RequestContext, responder: BaseResponder):
mediation_metadata = {}

try:
await mgr.receive_request(
connection = await mgr.receive_request(
context.message,
context.message_receipt,
mediation_id=mediation_metadata.get("id"),
)

if connection.accept == ConnRecord.ACCEPT_AUTO:
response = await mgr.create_response(connection)
await responder.send_reply(
response, connection_id=connection.connection_id
)
else:
self._logger.debug("Connection request will await acceptance")
except ConnectionManagerError as e:
self._logger.exception("Error receiving connection request")
if e.error_code:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,24 @@ async def test_called(self, mock_conn_mgr, request_context):
)
assert not responder.messages

@pytest.mark.asyncio
@async_mock.patch.object(handler, "ConnectionManager")
async def test_called_with_auto_response(self, mock_conn_mgr, request_context):
mock_conn_rec = async_mock.MagicMock()
mock_conn_rec.accept = ConnRecord.ACCEPT_AUTO
mock_conn_mgr.return_value.receive_request = async_mock.CoroutineMock(
return_value=mock_conn_rec
)
mock_conn_mgr.return_value.create_response = async_mock.CoroutineMock()
request_context.message = ConnectionRequest()
handler_inst = handler.ConnectionRequestHandler()
responder = MockResponder()
await handler_inst.handle(request_context, responder)
mock_conn_mgr.return_value.receive_request.assert_called_once_with(
request_context.message, request_context.message_receipt, mediation_id=None
)
assert responder.messages

@pytest.mark.asyncio
@async_mock.patch.object(handler, "ConnectionManager")
async def test_connection_record_with_mediation_metadata(
Expand Down
15 changes: 0 additions & 15 deletions aries_cloudagent/protocols/connections/v1_0/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,21 +641,6 @@ async def receive_request(
keylist_updates, connection_id=mediation_record.connection_id
)

if connection.accept == ConnRecord.ACCEPT_AUTO:
response = await self.create_response(connection, mediation_id=mediation_id)
responder = self.profile.inject_or(BaseResponder)
if responder:
await responder.send_reply(
response, connection_id=connection.connection_id
)
async with self.profile.session() as session:
# refetch connection for accurate state
connection = await ConnRecord.retrieve_by_id(
session, connection.connection_id
)
else:
self._logger.debug("Connection request will await acceptance")

return connection

async def create_response(
Expand Down
12 changes: 0 additions & 12 deletions aries_cloudagent/protocols/connections/v1_0/tests/test_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -713,12 +713,6 @@ async def test_receive_request_public_did_oob_invite(self):
conn_rec = await self.manager.receive_request(mock_request, receipt)
assert conn_rec

messages = self.responder.messages
assert len(messages) == 1
(result, target) = messages[0]
assert type(result) == ConnectionResponse
assert "connection_id" in target

async def test_receive_request_public_did_conn_invite(self):
async with self.profile.session() as session:
mock_request = async_mock.MagicMock()
Expand Down Expand Up @@ -755,12 +749,6 @@ async def test_receive_request_public_did_conn_invite(self):
conn_rec = await self.manager.receive_request(mock_request, receipt)
assert conn_rec

messages = self.responder.messages
assert len(messages) == 1
(result, target) = messages[0]
assert type(result) == ConnectionResponse
assert "connection_id" in target

async def test_receive_request_multi_use_multitenant(self):
async with self.profile.session() as session:
multiuse_info = await session.wallet.create_local_did(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Connection request handler under RFC 23 (DID exchange)."""

from .....connections.models.conn_record import ConnRecord
from .....messaging.base_handler import BaseHandler, BaseResponder, RequestContext

from ....problem_report.v1_0.message import ProblemReport
Expand Down Expand Up @@ -33,17 +34,35 @@ async def handle(self, context: RequestContext, responder: BaseResponder):
)
else:
mediation_metadata = {}

mediation_id = mediation_metadata.get("id")
try:
await mgr.receive_request(
conn_rec = await mgr.receive_request(
request=context.message,
recipient_did=context.message_receipt.recipient_did,
recipient_verkey=(
None
if context.message_receipt.recipient_did_public
else context.message_receipt.recipient_verkey
),
mediation_id=mediation_metadata.get("id"),
mediation_id=mediation_id,
)

# Auto respond
if conn_rec.accept == ConnRecord.ACCEPT_AUTO:
response = await mgr.create_response(
conn_rec,
mediation_id=mediation_id,
)
await responder.send_reply(
response, connection_id=conn_rec.connection_id
)
conn_rec.state = ConnRecord.State.RESPONSE.rfc23
async with context.session() as session:
await conn_rec.save(session, reason="Sent connection response")
else:
self._logger.debug("DID exchange request will await acceptance")

except DIDXManagerError as e:
self._logger.exception("Error receiving RFC 23 connection request")
if e.error_code:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,28 @@ async def test_called(self, mock_didx_mgr):
)
assert not responder.messages

@async_mock.patch.object(test_module, "DIDXManager")
async def test_called_with_auto_response(self, mock_didx_mgr):
mock_conn_rec = async_mock.MagicMock()
mock_conn_rec.accept = conn_record.ConnRecord.ACCEPT_AUTO
mock_conn_rec.save = async_mock.CoroutineMock()
mock_didx_mgr.return_value.receive_request = async_mock.CoroutineMock(
return_value=mock_conn_rec
)
mock_didx_mgr.return_value.create_response = async_mock.CoroutineMock()
self.ctx.message = DIDXRequest()
handler_inst = test_module.DIDXRequestHandler()
responder = MockResponder()
await handler_inst.handle(self.ctx, responder)

mock_didx_mgr.return_value.receive_request.assert_called_once_with(
request=self.ctx.message,
recipient_did=self.ctx.message_receipt.recipient_did,
recipient_verkey=None,
mediation_id=None,
)
assert responder.messages

@async_mock.patch.object(test_module, "DIDXManager")
async def test_problem_report(self, mock_didx_mgr):
mock_didx_mgr.return_value.receive_request = async_mock.CoroutineMock(
Expand Down
19 changes: 1 addition & 18 deletions aries_cloudagent/protocols/didexchange/v1_0/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,28 +549,11 @@ async def receive_request(
# Send keylist updates to mediator
mediation_record = await mediation_record_if_id(self.profile, mediation_id)
if keylist_updates and mediation_record:
responder = self.profile.inject_or(BaseResponder)
responder = self.profile.inject(BaseResponder)
await responder.send(
keylist_updates, connection_id=mediation_record.connection_id
)

if auto_accept:
response = await self.create_response(
conn_rec,
my_endpoint,
mediation_id=mediation_id,
)
responder = self.profile.inject_or(BaseResponder)
if responder:
await responder.send_reply(
response, connection_id=conn_rec.connection_id
)
conn_rec.state = ConnRecord.State.RESPONSE.rfc23
async with self.profile.session() as session:
await conn_rec.save(session, reason="Sent connection response")
else:
self._logger.debug("DID exchange request will await acceptance")

return conn_rec

async def create_response(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -540,11 +540,6 @@ async def test_receive_request_explicit_public_did(self):
)
assert conn_rec

messages = self.responder.messages
assert len(messages) == 2
(result, target) = messages[0]
assert "connection_id" in target

async def test_receive_request_invi_not_found(self):
async with self.profile.session() as session:
mock_request = async_mock.MagicMock(
Expand Down