Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: improvements for connections v1 testing #64

Merged
merged 4 commits into from
Jan 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 10 additions & 13 deletions controller/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
import asyncio
from contextlib import AsyncExitStack
from dataclasses import asdict, is_dataclass
import dataclasses
import logging
from json import dumps
from types import TracebackType
from typing import (
Any,
ClassVar,
Mapping,
Optional,
Protocol,
Expand Down Expand Up @@ -49,6 +51,8 @@ def deserialize(cls: Type[T], value: Mapping[str, Any]) -> T:
class Dataclass(Protocol):
"""Empty protocol for dataclass type hinting."""

__dataclass_fields__: ClassVar[dict[str, dataclasses.Field[Any]]]


Serializable = Union[Mapping[str, Any], Serde, BaseModel, Dataclass, None]

Expand Down Expand Up @@ -98,7 +102,7 @@ def _deserialize(
if issubclass(as_type, Serde):
return as_type.deserialize(value)
if is_dataclass(as_type):
return as_type(**value)
return cast(T, as_type(**value))
raise TypeError(f"Could not deserialize value into type {as_type.__name__}")


Expand Down Expand Up @@ -514,8 +518,9 @@ async def record(
)
except asyncio.TimeoutError:
raise ControllerError(
f"Record with topic {topic} not received before timeout"
)
f"Record from {self.label} with topic {topic} not received "
"before timeout"
) from None
return _deserialize(event.payload, record_type)

@overload
Expand All @@ -528,14 +533,6 @@ async def record_with_values(
) -> T:
...

@overload
async def record_with_values(
self,
topic: str,
**values,
) -> Mapping[str, Any]:
...

@overload
async def record_with_values(
self,
Expand Down Expand Up @@ -565,7 +562,7 @@ async def record_with_values(
)
except asyncio.TimeoutError:
raise ControllerError(
f"Record with topic {topic} and values {values} "
f"Record from {self.label} with topic {topic} and values\n\t{values}\n"
"not received before timeout"
)
) from None
return _deserialize(event.payload, record_type)
167 changes: 100 additions & 67 deletions controller/protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,39 @@ def _make_params(**kwargs) -> Mapping[str, Any]:
}


async def connection(inviter: Controller, invitee: Controller):
"""Connect two agents."""
async def connection_invitation(
inviter: Controller,
*,
use_public_did: bool = False,
multi_use: Optional[bool] = None,
):
"""Create a connection invitation.

This will always create an invite with auto_accept set to false to simplify
the connection function below.
"""
invitation = await inviter.post(
"/connections/create-invitation", json={}, response=InvitationResult
"/connections/create-invitation",
json={},
params=_make_params(
auto_accept=False, multi_use=multi_use, public=use_public_did
),
response=InvitationResult,
)
return invitation


async def connection(
inviter: Controller,
invitee: Controller,
*,
invitation: Optional[InvitationResult] = None,
):
"""Connect two agents."""

if invitation is None:
invitation = await connection_invitation(inviter)

inviter_conn = await inviter.get(
f"/connections/{invitation.connection_id}",
response=ConnRecord,
Expand All @@ -103,10 +130,11 @@ async def connection(inviter: Controller, invitee: Controller):
f"/connections/{invitee_conn.connection_id}/accept-invitation",
)

await inviter.record_with_values(
inviter_conn = await inviter.record_with_values(
topic="connections",
connection_id=inviter_conn.connection_id,
invitation_key=inviter_conn.invitation_key,
rfc23_state="request-received",
record_type=ConnRecord,
)

inviter_conn = await inviter.post(
Expand Down Expand Up @@ -140,33 +168,44 @@ async def connection(inviter: Controller, invitee: Controller):
return inviter_conn, invitee_conn


async def oob_invitation(
inviter: Controller,
*,
use_public_did: bool = False,
multi_use: Optional[bool] = None,
) -> InvitationMessage:
"""Create an OOB invitation.

This will always create an invite with auto_accept set to false to simplify
the didexchange function below.
"""
invite_record = await inviter.post(
"/out-of-band/create-invitation",
json=InvitationCreateRequest.parse_obj(
{
"handshake_protocols": ["https://didcomm.org/didexchange/1.0"],
"use_public_did": use_public_did,
}
),
params=_make_params(
auto_accept=False,
multi_use=multi_use,
),
response=InvitationRecord,
)
return invite_record.invitation


async def didexchange(
inviter: Controller,
invitee: Controller,
*,
invite: Optional[InvitationMessage] = None,
use_public_did: bool = False,
auto_accept: Optional[bool] = None,
multi_use: Optional[bool] = None,
use_existing_connection: bool = False,
):
"""Connect two agents using did exchange protocol."""
if not invite:
invite_record = await inviter.post(
"/out-of-band/create-invitation",
json=InvitationCreateRequest.parse_obj(
{
"handshake_protocols": ["https://didcomm.org/didexchange/1.0"],
"use_public_did": use_public_did,
}
),
params=_make_params(
auto_accept=auto_accept,
multi_use=multi_use,
),
response=InvitationRecord,
)
invite = invite_record.invitation
invite = await oob_invitation(inviter)

inviter_conn = (
await inviter.get(
Expand Down Expand Up @@ -201,51 +240,45 @@ async def didexchange(
)
return inviter_conn, invitee_conn

if not auto_accept:
invitee_conn = await invitee.post(
f"/didexchange/{invitee_oob_record.connection_id}/accept-invitation",
response=ConnRecord,
)
inviter_oob_record = await inviter.record_with_values(
topic="out_of_band",
record_type=OobRecord,
connection_id=inviter_conn.connection_id,
state="done",
)
# Overwrite multiuse invitation connection with actual connection
inviter_conn = await inviter.record_with_values(
topic="connections",
record_type=ConnRecord,
rfc23_state="request-received",
invitation_key=inviter_oob_record.our_recipient_key,
)
inviter_conn = await inviter.post(
f"/didexchange/{inviter_conn.connection_id}/accept-request",
response=ConnRecord,
)
invitee_conn = await invitee.post(
f"/didexchange/{invitee_oob_record.connection_id}/accept-invitation",
response=ConnRecord,
)
inviter_oob_record = await inviter.record_with_values(
topic="out_of_band",
record_type=OobRecord,
connection_id=inviter_conn.connection_id,
state="done",
)
# Overwrite multiuse invitation connection with actual connection
inviter_conn = await inviter.record_with_values(
topic="connections",
record_type=ConnRecord,
rfc23_state="request-received",
invitation_key=inviter_oob_record.our_recipient_key,
)
inviter_conn = await inviter.post(
f"/didexchange/{inviter_conn.connection_id}/accept-request",
response=ConnRecord,
)

await invitee.record_with_values(
topic="connections",
connection_id=invitee_conn.connection_id,
rfc23_state="response-received",
)
invitee_conn = await invitee.record_with_values(
topic="connections",
connection_id=invitee_conn.connection_id,
rfc23_state="completed",
record_type=ConnRecord,
)
inviter_conn = await inviter.record_with_values(
topic="connections",
connection_id=inviter_conn.connection_id,
rfc23_state="completed",
record_type=ConnRecord,
)
else:
invitee_conn = await invitee.get(
f"/connections/{invitee_oob_record.connection_id}",
response=ConnRecord,
)
await invitee.record_with_values(
topic="connections",
connection_id=invitee_conn.connection_id,
rfc23_state="response-received",
)
invitee_conn = await invitee.record_with_values(
topic="connections",
connection_id=invitee_conn.connection_id,
rfc23_state="completed",
record_type=ConnRecord,
)
inviter_conn = await inviter.record_with_values(
topic="connections",
connection_id=inviter_conn.connection_id,
rfc23_state="completed",
record_type=ConnRecord,
)

return inviter_conn, invitee_conn

Expand Down
8 changes: 6 additions & 2 deletions examples/simple/docker-compose.yml
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
version: "3"
services:
alice:
image: ghcr.io/hyperledger/aries-cloudagent-python:py3.9-0.10.1
image: ghcr.io/hyperledger/aries-cloudagent-python:py3.9-0.11.0
ports:
- "3001:3001"
environment:
RUST_LOG: 'aries-askar::log::target=error'
command: >
start
--label Alice
Expand Down Expand Up @@ -34,9 +36,11 @@ services:
condition: service_started

bob:
image: ghcr.io/hyperledger/aries-cloudagent-python:py3.9-0.10.1
image: ghcr.io/hyperledger/aries-cloudagent-python:py3.9-0.11.0
ports:
- "3002:3001"
environment:
RUST_LOG: 'aries-askar::log::target=error'
command: >
start
--label Bob
Expand Down
3 changes: 2 additions & 1 deletion examples/simple/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from controller import Controller
from controller.logging import logging_to_stdout
from controller.protocols import didexchange
from controller.protocols import connection, didexchange

ALICE = getenv("ALICE", "http://alice:3001")
BOB = getenv("BOB", "http://bob:3001")
Expand All @@ -17,6 +17,7 @@
async def main():
"""Test Controller protocols."""
async with Controller(base_url=ALICE) as alice, Controller(base_url=BOB) as bob:
await connection(alice, bob)
await didexchange(alice, bob)


Expand Down
7 changes: 6 additions & 1 deletion tests/test_protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
didexchange,
indy_anoncreds_publish_revocation,
indy_anoncreds_revoke,
oob_invitation,
)


Expand All @@ -40,7 +41,11 @@ async def test_did_exchange(did_exchange: Tuple[ConnRecord, ConnRecord]):
@pytest.mark.asyncio
async def test_did_exchange_with_multiuse(alice, bob):
"""Testing that dids are exchanged successfully."""
alice_conn, bob_conn = await didexchange(alice, bob, multi_use=True)
invite = await oob_invitation(alice, multi_use=True)
alice_conn, bob_conn = await didexchange(alice, bob, invite=invite)
assert alice_conn.rfc23_state == "completed"
assert bob_conn.rfc23_state == "completed"
alice_conn, bob_conn = await didexchange(alice, bob, invite=invite)
assert alice_conn.rfc23_state == "completed"
assert bob_conn.rfc23_state == "completed"

Expand Down