Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make mediation invitation parameter idempotent #1413

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions aries_cloudagent/commands/provision.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@
from ..config.ledger import get_genesis_transactions, ledger_config
from ..config.util import common_config
from ..config.wallet import wallet_config
from ..protocols.coordinate_mediation.mediation_invite_store import (
MediationInviteStore,
MediationInviteRecord,
)
from ..storage.base import BaseStorage

from . import PROG

Expand All @@ -35,6 +40,14 @@ async def provision(settings: dict):

root_profile, public_did = await wallet_config(context, provision=True)

# store mediator invite url if provided
mediation_invite = settings.get("mediation.invite", None)
if mediation_invite:
async with root_profile.session() as session:
await MediationInviteStore(session.context.inject(BaseStorage)).store(
MediationInviteRecord.unused(mediation_invite)
)

if await ledger_config(root_profile, public_did and public_did.did, True):
print("Ledger configured")
else:
Expand Down
19 changes: 18 additions & 1 deletion aries_cloudagent/commands/tests/test_provision.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@

from ...config.base import ConfigError
from ...config.error import ArgsParseError
from ...core.profile import Profile
from .. import provision as test_module
from ...protocols.coordinate_mediation.mediation_invite_store import (
MediationInviteRecord,
)


class TestProvision(AsyncTestCase):
Expand Down Expand Up @@ -68,3 +70,18 @@ def test_main(self):
) as mock_execute:
test_module.main()
mock_execute.assert_called_once

async def test_provision_should_store_provided_mediation_invite(self):
# given
mediation_invite = "test-invite"

with async_mock.patch.object(
test_module.MediationInviteStore, "store"
) as invite_store:
# when
await test_module.provision({"mediation.invite": mediation_invite})

# then
invite_store.assert_called_with(
MediationInviteRecord(mediation_invite, False)
)
61 changes: 41 additions & 20 deletions aries_cloudagent/config/argparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -1143,25 +1143,18 @@ def get_settings(self, args: Namespace):
return settings


@group(CAT_START)
class MediationGroup(ArgumentGroup):
"""Mediation settings."""
@group(CAT_START, CAT_PROVISION)
class MediationInviteGroup(ArgumentGroup):
"""
Mediation invitation settings.

GROUP_NAME = "Mediation"
These can be provided at provision- and start-time.
"""

GROUP_NAME = "Mediation invitation"

def add_arguments(self, parser: ArgumentParser):
"""Add mediation command line arguments to the parser."""
parser.add_argument(
"--open-mediation",
action="store_true",
env_var="ACAPY_MEDIATION_OPEN",
help=(
"Enables didcomm mediation. After establishing a connection, "
"if enabled, an agent may request message mediation, which will "
"allow the mediator to forward messages on behalf of the recipient. "
"See aries-rfc:0211."
),
)
"""Add mediation invitation command line arguments to the parser."""
parser.add_argument(
"--mediator-invitation",
type=str,
Expand All @@ -1182,6 +1175,38 @@ def add_arguments(self, parser: ArgumentParser):
"Default: false."
),
)

def get_settings(self, args: Namespace):
"""Extract mediation invitation settings."""
settings = {}
if args.mediator_invitation:
settings["mediation.invite"] = args.mediator_invitation
if args.mediator_connections_invite:
settings["mediation.connections_invite"] = True

return settings


@group(CAT_START)
class MediationGroup(ArgumentGroup):
"""Mediation settings."""

GROUP_NAME = "Mediation"

def add_arguments(self, parser: ArgumentParser):
"""Add mediation command line arguments to the parser."""
parser.add_argument(
"--open-mediation",
action="store_true",
env_var="ACAPY_MEDIATION_OPEN",
help=(
"Enables didcomm mediation. After establishing a connection, "
"if enabled, an agent may request message mediation, which will "
"allow the mediator to forward messages on behalf of the recipient. "
"See aries-rfc:0211."
),
)

parser.add_argument(
"--default-mediator-id",
type=str,
Expand All @@ -1201,14 +1226,10 @@ def get_settings(self, args: Namespace):
settings = {}
if args.open_mediation:
settings["mediation.open"] = True
if args.mediator_invitation:
settings["mediation.invite"] = args.mediator_invitation
if args.default_mediator_id:
settings["mediation.default_id"] = args.default_mediator_id
if args.clear_default_mediator:
settings["mediation.clear"] = True
if args.mediator_connections_invite:
settings["mediation.connections_invite"] = True

if args.clear_default_mediator and args.default_mediator_id:
raise ArgsParseError(
Expand Down
81 changes: 53 additions & 28 deletions aries_cloudagent/core/conductor.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,10 @@
ConnectionInvitation,
)
from ..protocols.coordinate_mediation.v1_0.manager import MediationManager
from ..protocols.coordinate_mediation.mediation_invite_store import MediationInviteStore
from ..protocols.out_of_band.v1_0.manager import OutOfBandManager
from ..protocols.out_of_band.v1_0.messages.invitation import HSProto, InvitationMessage
from ..storage.base import BaseStorage
from ..transport.inbound.manager import InboundTransportManager
from ..transport.inbound.message import InboundMessage
from ..transport.outbound.base import OutboundDeliveryError
Expand Down Expand Up @@ -303,41 +305,64 @@ async def start(self) -> None:
except Exception:
LOGGER.exception("Error creating invitation")

# Accept mediation invitation if specified
mediation_invitation: str = context.settings.get("mediation.invite")
if mediation_invitation:
# mediation connection establishment
provided_invite: str = context.settings.get("mediation.invite")
async with self.root_profile.session() as session:
try:
mediation_connections_invite = context.settings.get(
"mediation.connections_invite", False
)
invitation_handler = (
ConnectionInvitation
if mediation_connections_invite
else InvitationMessage
invite_store = MediationInviteStore(session.context.inject(BaseStorage))
mediation_invite_record = (
await invite_store.get_mediation_invite_record(provided_invite)
)
except Exception:
LOGGER.exception("Error retrieving mediator invitation")
mediation_invite_record = None

async with self.root_profile.session() as session:
mgr = (
ConnectionManager(session)
# Accept mediation invitation if one was specified or stored
if mediation_invite_record is not None:
try:
mediation_connections_invite = context.settings.get(
"mediation.connections_invite", False
)
invitation_handler = (
ConnectionInvitation
if mediation_connections_invite
else OutOfBandManager(session)
else InvitationMessage
)

conn_record = await mgr.receive_invitation(
invitation=invitation_handler.from_url(mediation_invitation),
auto_accept=True,
)
if not mediation_invite_record.used:
# clear previous mediator configuration before establishing a
# new one
await MediationManager(session.profile).clear_default_mediator()

await conn_record.metadata_set(
session, MediationManager.SEND_REQ_AFTER_CONNECTION, True
)
await conn_record.metadata_set(
session, MediationManager.SET_TO_DEFAULT_ON_GRANTED, True
)
print("Attempting to connect to mediator...")
del mgr
except Exception:
LOGGER.exception("Error accepting mediation invitation")
mgr = (
ConnectionManager(session)
if mediation_connections_invite
else OutOfBandManager(session)
)

conn_record = await mgr.receive_invitation(
invitation=invitation_handler.from_url(
mediation_invite_record.invite
),
auto_accept=True,
)
await (
MediationInviteStore(
session.context.inject(BaseStorage)
).mark_default_invite_as_used()
)

await conn_record.metadata_set(
session, MediationManager.SEND_REQ_AFTER_CONNECTION, True
)
await conn_record.metadata_set(
session, MediationManager.SET_TO_DEFAULT_ON_GRANTED, True
)

print("Attempting to connect to mediator...")
del mgr
except Exception:
LOGGER.exception("Error accepting mediation invitation")

async def stop(self, timeout=1.0):
"""Stop the agent."""
Expand Down
Loading