Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Align OutOfBandManager.receive_invitation with other connection managers #1382

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 27 additions & 35 deletions aries_cloudagent/core/conductor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from ..config.ledger import get_genesis_transactions, ledger_config
from ..config.logging import LoggingConfigurator
from ..config.wallet import wallet_config
from ..connections.models.conn_record import ConnRecord
from ..core.profile import Profile
from ..ledger.error import LedgerConfigError, LedgerTransactionError
from ..messaging.responder import BaseResponder
Expand Down Expand Up @@ -303,45 +302,38 @@ async def start(self) -> None:
LOGGER.exception("Error creating invitation")

# Accept mediation invitation if specified
mediation_invitation = context.settings.get("mediation.invite")
mediation_invitation: str = context.settings.get("mediation.invite")
if mediation_invitation:
try:
mediation_connections_invite = context.settings.get(
"mediation.connections_invite", False
)
if mediation_connections_invite:
async with self.root_profile.session() as session:
mgr = ConnectionManager(session)
conn_record = await mgr.receive_invitation(
invitation=ConnectionInvitation.from_url(
mediation_invitation
),
auto_accept=True,
)
await conn_record.metadata_set(
session, MediationManager.SEND_REQ_AFTER_CONNECTION, True
)
await conn_record.metadata_set(
session, MediationManager.SET_TO_DEFAULT_ON_GRANTED, True
)
print("Attempting to connect to mediator...")
del mgr
else:
async with self.root_profile.session() as session:
mgr = OutOfBandManager(session)
conn_record_dict = await mgr.receive_invitation(
invi_msg=InvitationMessage.from_url(mediation_invitation),
auto_accept=True,
)
conn_record = ConnRecord.deserialize(conn_record_dict)
await conn_record.metadata_set(
session, MediationManager.SEND_REQ_AFTER_CONNECTION, True
)
await conn_record.metadata_set(
session, MediationManager.SET_TO_DEFAULT_ON_GRANTED, True
)
print("Attempting to connect to mediator...")
del mgr
invitation_handler = (
ConnectionInvitation
if mediation_connections_invite
else InvitationMessage
)

async with self.root_profile.session() as session:
mgr = (
ConnectionManager(session)
if mediation_connections_invite
else OutOfBandManager(session)
)

conn_record = await mgr.receive_invitation(
invitation=invitation_handler.from_url(mediation_invitation),
auto_accept=True,
)

await conn_record.metadata_set(
session, MediationManager.SEND_REQ_AFTER_CONNECTION, True
)
await conn_record.metadata_set(
session, MediationManager.SET_TO_DEFAULT_ON_GRANTED, True
)
print("Attempting to connect to mediator...")
del mgr
except Exception:
LOGGER.exception("Error accepting mediation invitation")

Expand Down
3 changes: 1 addition & 2 deletions aries_cloudagent/core/tests/test_conductor.py
Original file line number Diff line number Diff line change
Expand Up @@ -775,7 +775,6 @@ async def test_mediator_invitation_0434(self):
)
conn_record.accept = ConnRecord.ACCEPT_MANUAL
await conn_record.save(await conductor.root_profile.session())
conn_record_dict = conn_record.serialize()
with async_mock.patch.object(
test_module.InvitationMessage, "from_url"
) as mock_from_url, async_mock.patch.object(
Expand All @@ -784,7 +783,7 @@ async def test_mediator_invitation_0434(self):
async_mock.MagicMock(
return_value=async_mock.MagicMock(
receive_invitation=async_mock.CoroutineMock(
return_value=conn_record_dict
return_value=conn_record
)
)
),
Expand Down
42 changes: 21 additions & 21 deletions aries_cloudagent/protocols/out_of_band/v1_0/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,17 +363,17 @@ async def create_invitation(

async def receive_invitation(
self,
invi_msg: InvitationMessage,
invitation: InvitationMessage,
use_existing_connection: bool = True,
auto_accept: bool = None,
alias: str = None,
mediation_id: str = None,
) -> dict:
) -> ConnRecord:
"""
Receive an out of band invitation message.

Args:
invi_msg: invitation message
invitation: invitation message
use_existing_connection: whether to use existing connection if possible
auto_accept: whether to accept the invitation automatically
alias: Alias for connection record
Expand All @@ -390,15 +390,15 @@ async def receive_invitation(
mediation_id = None

# There must be exactly 1 service entry
if len(invi_msg.services) != 1:
if len(invitation.services) != 1:
raise OutOfBandManagerError("service array must have exactly one element")

if not (invi_msg.requests_attach or invi_msg.handshake_protocols):
if not (invitation.requests_attach or invitation.handshake_protocols):
raise OutOfBandManagerError(
"Invitation must specify handshake_protocols, requests_attach, or both"
)
# Get the single service item
oob_service_item = invi_msg.services[0]
oob_service_item = invitation.services[0]
if isinstance(oob_service_item, ServiceMessage):
service = oob_service_item
public_did = None
Expand Down Expand Up @@ -437,7 +437,7 @@ async def receive_invitation(
for hsp in dict.fromkeys(
[
DIDCommPrefix.unqualify(proto)
for proto in invi_msg.handshake_protocols
for proto in invitation.handshake_protocols
]
)
]
Expand All @@ -454,15 +454,15 @@ async def receive_invitation(
)
if conn_rec is not None:
num_included_protocols = len(unq_handshake_protos)
num_included_req_attachments = len(invi_msg.requests_attach)
num_included_req_attachments = len(invitation.requests_attach)
# With handshake protocol, request attachment; use existing connection
if (
num_included_protocols >= 1
and num_included_req_attachments == 0
and use_existing_connection
):
await self.create_handshake_reuse_message(
invi_msg=invi_msg,
invi_msg=invitation,
conn_record=conn_rec,
)
try:
Expand Down Expand Up @@ -526,7 +526,7 @@ async def receive_invitation(
if proto is HSProto.RFC23:
didx_mgr = DIDXManager(self._session)
conn_rec = await didx_mgr.receive_invitation(
invitation=invi_msg,
invitation=invitation,
their_public_did=public_did,
auto_accept=auto_accept,
alias=alias,
Expand All @@ -543,9 +543,9 @@ async def receive_invitation(
] or []
connection_invitation = ConnectionInvitation.deserialize(
{
"@id": invi_msg._id,
"@id": invitation._id,
"@type": DIDCommPrefix.qualify_current(proto.name),
"label": invi_msg.label,
"label": invitation.label,
"recipientKeys": service.recipient_keys,
"serviceEndpoint": service.service_endpoint,
"routingKeys": service.routing_keys,
Expand All @@ -563,8 +563,8 @@ async def receive_invitation(
break

# Request Attach
if len(invi_msg.requests_attach) >= 1 and conn_rec is not None:
req_attach = invi_msg.requests_attach[0]
if len(invitation.requests_attach) >= 1 and conn_rec is not None:
req_attach = invitation.requests_attach[0]
if isinstance(req_attach, AttachDecorator):
if req_attach.data is not None:
unq_req_attach_type = DIDCommPrefix.unqualify(
Expand All @@ -575,14 +575,14 @@ async def receive_invitation(
req_attach=req_attach,
service=service,
conn_rec=conn_rec,
trace=(invi_msg._trace is not None),
trace=(invitation._trace is not None),
)
elif unq_req_attach_type == PRES_20_REQUEST:
await self._process_pres_request_v2(
req_attach=req_attach,
service=service,
conn_rec=conn_rec,
trace=(invi_msg._trace is not None),
trace=(invitation._trace is not None),
)
elif unq_req_attach_type == CREDENTIAL_OFFER:
if auto_accept or self._session.settings.get(
Expand All @@ -597,12 +597,12 @@ async def receive_invitation(
LOGGER.warning(
"Connection not ready to receive credential, "
f"For connection_id:{conn_rec.connection_id} and "
f"invitation_msg_id {invi_msg._id}",
f"invitation_msg_id {invitation._id}",
)
await self._process_cred_offer_v1(
req_attach=req_attach,
conn_rec=conn_rec,
trace=(invi_msg._trace is not None),
trace=(invitation._trace is not None),
)
elif unq_req_attach_type == CRED_20_OFFER:
if auto_accept or self._session.settings.get(
Expand All @@ -617,12 +617,12 @@ async def receive_invitation(
LOGGER.warning(
"Connection not ready to receive credential, "
f"For connection_id:{conn_rec.connection_id} and "
f"invitation_msg_id {invi_msg._id}",
f"invitation_msg_id {invitation._id}",
)
await self._process_cred_offer_v2(
req_attach=req_attach,
conn_rec=conn_rec,
trace=(invi_msg._trace is not None),
trace=(invitation._trace is not None),
)
else:
raise OutOfBandManagerError(
Expand All @@ -636,7 +636,7 @@ async def receive_invitation(
else:
raise OutOfBandManagerError("requests~attach is not properly formatted")

return conn_rec.serialize()
return conn_rec

async def _process_pres_request_v1(
self,
Expand Down
2 changes: 1 addition & 1 deletion aries_cloudagent/protocols/out_of_band/v1_0/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ async def invitation_receive(request: web.BaseRequest):
except (DIDXManagerError, StorageError, BaseModelError) as err:
raise web.HTTPBadRequest(reason=err.roll_up) from err

return web.json_response(result)
return web.json_response(result.serialize())


async def register(app: web.Application):
Expand Down
Loading