Skip to content

Commit

Permalink
allow recipients with different key types in ECDH-ES
Browse files Browse the repository at this point in the history
Signed-off-by: Andrew Whitehead <cywolf@gmail.com>
  • Loading branch information
andrewwhitehead committed Aug 18, 2021
1 parent a06d3a6 commit 8022586
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 55 deletions.
66 changes: 49 additions & 17 deletions aries_cloudagent/askar/didcomm/tests/test_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@
from aries_askar import AskarError, Key, KeyAlg, Session

from ....config.injection_context import InjectionContext
from ....utils.jwe import b64url, JweEnvelope
from ....utils.jwe import JweRecipient, b64url, JweEnvelope

from ...profile import AskarProfileManager
from .. import v2 as test_module


ALICE_KID = "did:example:alice#key-1"
BOB_KID = "did:example:bob#key-1"
CAROL_KID = "did:example:carol#key-2"
MESSAGE = b"Expecto patronum"


Expand Down Expand Up @@ -41,8 +42,12 @@ async def test_es_round_trip(self, session: Session):
alg = KeyAlg.X25519
bob_sk = Key.generate(alg)
bob_pk = Key.from_jwk(bob_sk.get_jwk_public())
carol_sk = Key.generate(KeyAlg.P256) # testing mixed recipient key types
carol_pk = Key.from_jwk(carol_sk.get_jwk_public())

enc_message = test_module.ecdh_es_encrypt({BOB_KID: bob_pk}, MESSAGE)
enc_message = test_module.ecdh_es_encrypt(
{BOB_KID: bob_pk, CAROL_KID: carol_pk}, MESSAGE
)

# receiver must have the private keypair accessible
await session.insert_key("my_sk", bob_sk, tags={"kid": BOB_KID})
Expand All @@ -65,13 +70,6 @@ async def test_es_encrypt_x(self, session: Session):
):
_ = test_module.ecdh_es_encrypt({}, MESSAGE)

alt_sk = Key.generate(KeyAlg.P256)
alt_pk = Key.from_jwk(alt_sk.get_jwk_public())
with pytest.raises(
test_module.DidcommEnvelopeError, match="key types must be consistent"
):
_ = test_module.ecdh_es_encrypt({BOB_KID: bob_pk, "alt": alt_pk}, MESSAGE)

with async_mock.patch(
"aries_askar.Key.generate",
async_mock.MagicMock(side_effect=AskarError(99, "")),
Expand All @@ -96,34 +94,54 @@ async def test_es_encrypt_x(self, session: Session):
async def test_es_decrypt_x(self):
alg = KeyAlg.X25519
bob_sk = Key.generate(alg)
bob_pk = Key.from_jwk(bob_sk.get_jwk_public())

message_unknown_alg = JweEnvelope(
protected={"alg": "NOT-SUPPORTED"},
)
message_unknown_alg.add_recipient(
JweRecipient(encrypted_key=b"0000", header={"kid": BOB_KID})
)
with pytest.raises(
test_module.DidcommEnvelopeError,
match="Unsupported ECDH-ES algorithm",
):
_ = test_module.ecdh_es_decrypt(message_unknown_alg, bob_sk, b"0000")
_ = test_module.ecdh_es_decrypt(
message_unknown_alg,
BOB_KID,
bob_sk,
)

message_unknown_enc = JweEnvelope(
protected={"alg": "ECDH-ES+A128KW", "enc": "UNKNOWN"},
)
message_unknown_enc.add_recipient(
JweRecipient(encrypted_key=b"0000", header={"kid": BOB_KID})
)
with pytest.raises(
test_module.DidcommEnvelopeError,
match="Unsupported ECDH-ES content encryption",
):
_ = test_module.ecdh_es_decrypt(message_unknown_enc, bob_sk, b"0000")
_ = test_module.ecdh_es_decrypt(
message_unknown_enc,
BOB_KID,
bob_sk,
)

message_invalid_epk = JweEnvelope(
protected={"alg": "ECDH-ES+A128KW", "enc": "A256GCM", "epk": {}},
)
message_invalid_epk.add_recipient(
JweRecipient(encrypted_key=b"0000", header={"kid": BOB_KID})
)
with pytest.raises(
test_module.DidcommEnvelopeError,
match="Error loading ephemeral key",
):
_ = test_module.ecdh_es_decrypt(message_invalid_epk, bob_sk, b"0000")
_ = test_module.ecdh_es_decrypt(
message_invalid_epk,
BOB_KID,
bob_sk,
)

@pytest.mark.asyncio
async def test_1pu_round_trip(self, session: Session):
Expand Down Expand Up @@ -200,39 +218,53 @@ async def test_1pu_decrypt_x(self):
alice_sk = Key.generate(alg)
alice_pk = Key.from_jwk(alice_sk.get_jwk_public())
bob_sk = Key.generate(alg)
bob_pk = Key.from_jwk(bob_sk.get_jwk_public())

message_unknown_alg = JweEnvelope(
protected={"alg": "NOT-SUPPORTED"},
)
message_unknown_alg.add_recipient(
JweRecipient(encrypted_key=b"0000", header={"kid": BOB_KID})
)
with pytest.raises(
test_module.DidcommEnvelopeError,
match="Unsupported ECDH-1PU algorithm",
):
_ = test_module.ecdh_1pu_decrypt(
message_unknown_alg, alice_pk, bob_sk, b"0000"
message_unknown_alg,
BOB_KID,
bob_sk,
alice_pk,
)

message_unknown_enc = JweEnvelope(
protected={"alg": "ECDH-1PU+A128KW", "enc": "UNKNOWN"},
)
message_unknown_enc.add_recipient(
JweRecipient(encrypted_key=b"0000", header={"kid": BOB_KID})
)
with pytest.raises(
test_module.DidcommEnvelopeError,
match="Unsupported ECDH-1PU content encryption",
):
_ = test_module.ecdh_1pu_decrypt(
message_unknown_enc, alice_pk, bob_sk, b"0000"
message_unknown_enc, BOB_KID, bob_sk, alice_pk
)

message_invalid_epk = JweEnvelope(
protected={"alg": "ECDH-1PU+A128KW", "enc": "A256CBC-HS512", "epk": {}},
)
message_invalid_epk.add_recipient(
JweRecipient(encrypted_key=b"0000", header={"kid": BOB_KID})
)
with pytest.raises(
test_module.DidcommEnvelopeError,
match="Error loading ephemeral key",
):
_ = test_module.ecdh_1pu_decrypt(
message_invalid_epk, alice_pk, bob_sk, b"0000"
message_invalid_epk,
BOB_KID,
bob_sk,
alice_pk,
)

@pytest.mark.asyncio
Expand Down
78 changes: 40 additions & 38 deletions aries_cloudagent/askar/didcomm/v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,38 +28,31 @@ def ecdh_es_encrypt(to_verkeys: Mapping[str, Key], message: bytes) -> bytes:
if not to_verkeys:
raise DidcommEnvelopeError("No message recipients")

agree_alg = None
for (kid, recip_key) in to_verkeys.items():
if agree_alg:
if agree_alg != recip_key.algorithm:
raise DidcommEnvelopeError("Recipient key types must be consistent")
else:
agree_alg = recip_key.algorithm

try:
cek = Key.generate(enc_alg)
except AskarError:
raise DidcommEnvelopeError("Error creating content encryption key")

try:
epk = Key.generate(agree_alg)
except AskarError:
raise DidcommEnvelopeError("Error creating ephemeral key")

for (kid, recip_key) in to_verkeys.items():
try:
epk = Key.generate(recip_key.algorithm, ephemeral=True)
except AskarError:
raise DidcommEnvelopeError("Error creating ephemeral key")
enc_key = ecdh.EcdhEs(alg_id, None, None).sender_wrap_key(
wrap_alg, epk, recip_key, cek
)
wrapper.add_recipient(
JweRecipient(encrypted_key=enc_key.ciphertext, header={"kid": kid})
JweRecipient(
encrypted_key=enc_key.ciphertext,
header={"kid": kid, "epk": epk.get_jwk_public()},
)
)

wrapper.set_protected(
OrderedDict(
[
("alg", alg_id),
("enc", enc_id),
("epk", json.loads(epk.get_jwk_public())),
]
)
)
Expand All @@ -73,7 +66,9 @@ def ecdh_es_encrypt(to_verkeys: Mapping[str, Key], message: bytes) -> bytes:


def ecdh_es_decrypt(
wrapper: JweEnvelope, recip_key: Key, encrypted_key: bytes
wrapper: JweEnvelope,
recip_kid: str,
recip_key: Key,
) -> bytes:
"""Decode a message with DIDComm v2 anonymous encryption."""

Expand All @@ -83,25 +78,29 @@ def ecdh_es_decrypt(
else:
raise DidcommEnvelopeError(f"Unsupported ECDH-ES algorithm: {alg_id}")

enc_alg = wrapper.protected.get("enc")
recip = wrapper.get_recipient(recip_kid)
if not recip:
raise DidcommEnvelopeError(f"Recipient header not found: {recip_kid}")

enc_alg = recip.header.get("enc")
if enc_alg not in ("A128GCM", "A256GCM", "A128CBC-HS256", "A256CBC-HS512", "XC20P"):
raise DidcommEnvelopeError(f"Unsupported ECDH-ES content encryption: {enc_alg}")

try:
epk = Key.from_jwk(wrapper.protected.get("epk"))
epk = Key.from_jwk(recip.header.get("epk"))
except AskarError:
raise DidcommEnvelopeError("Error loading ephemeral key")

apu = wrapper.protected.get("apu")
apv = wrapper.protected.get("apv")
apu = recip.header.get("apu")
apv = recip.header.get("apv")

try:
cek = ecdh.EcdhEs(alg_id, apu, apv).receiver_unwrap_key(
wrap_alg,
enc_alg,
epk,
recip_key,
encrypted_key,
recip.encrypted_key,
)
except AskarError:
raise DidcommEnvelopeError("Error decrypting content encryption key")
Expand Down Expand Up @@ -140,7 +139,7 @@ def ecdh_1pu_encrypt(
raise DidcommEnvelopeError("Error creating content encryption key")

try:
epk = Key.generate(agree_alg)
epk = Key.generate(agree_alg, ephemeral=True)
except AskarError:
raise DidcommEnvelopeError("Error creating ephemeral key")

Expand Down Expand Up @@ -186,7 +185,10 @@ def ecdh_1pu_encrypt(


def ecdh_1pu_decrypt(
wrapper: JweEnvelope, sender_key: Key, recip_key: Key, encrypted_key: bytes
wrapper: JweEnvelope,
recip_kid: str,
recip_key: Key,
sender_key: Key,
) -> Tuple[str, str, str]:
"""Decode a message with DIDComm v2 authenticated encryption."""

Expand All @@ -202,6 +204,10 @@ def ecdh_1pu_decrypt(
f"Unsupported ECDH-1PU content encryption: {enc_alg}"
)

recip = wrapper.get_recipient(recip_kid)
if not recip:
raise DidcommEnvelopeError(f"Recipient header not found: {recip_kid}")

try:
epk = Key.from_jwk(wrapper.protected.get("epk"))
except AskarError:
Expand All @@ -217,7 +223,7 @@ def ecdh_1pu_decrypt(
epk,
sender_key,
recip_key,
encrypted_key,
recip.encrypted_key,
cc_tag=wrapper.tag,
)
except AskarError:
Expand Down Expand Up @@ -254,18 +260,14 @@ async def unpack_message(
sender_kid = None
recip_key = None
recip_kid = None
encrypted_key = None
for recip in wrapper.recipients:
kid = recip.header.get("kid")
if kid:
recip_key_entry = next(
await session.fetch_all_keys(tag_filter={"kid": kid}), None
)
if recip_key_entry:
recip_kid = kid
recip_key = recip_key_entry.key
encrypted_key = recip.encrypted_key
break
for kid in wrapper.recipient_key_ids:
recip_key_entry = next(
await session.fetch_all_keys(tag_filter={"kid": kid}), None
)
if recip_key_entry:
recip_kid = kid
recip_key = recip_key_entry.key
break

if not recip_key:
raise DidcommEnvelopeError("No recognized recipient key")
Expand All @@ -292,8 +294,8 @@ async def unpack_message(
if not sender_key_entry:
raise DidcommEnvelopeError("Sender public key not found")
sender_key = sender_key_entry.key
plaintext = ecdh_1pu_decrypt(wrapper, sender_key, recip_key, encrypted_key)
plaintext = ecdh_1pu_decrypt(wrapper, recip_kid, recip_key, sender_key)
else:
plaintext = ecdh_es_decrypt(wrapper, recip_key, encrypted_key)
plaintext = ecdh_es_decrypt(wrapper, recip_kid, recip_key)

return plaintext, recip_kid, sender_kid
16 changes: 16 additions & 0 deletions aries_cloudagent/utils/jwe.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,22 @@ def recipients_json(self) -> List[Dict[str, Any]]:
"""Encode the current recipients for JSON."""
return [recip.serialize() for recip in self._recipients]

@property
def recipient_key_ids(self) -> Iterable[JweRecipient]:
"""Accessor for an iterator over the JWE recipient key identifiers."""
for recip in self._recipients:
if recip.header and "kid" in recip.header:
yield recip.header["kid"]

def get_recipient(self, kid: str) -> JweRecipient:
"""Find a recipient by key ID."""
for recip in self._recipients:
if recip.header and recip.header.get("kid") == kid:
header = self.protected.copy()
header.update(self.unprotected)
header.update(recip.header)
return JweRecipient(encrypted_key=recip.encrypted_key, header=header)

@property
def combined_aad(self) -> bytes:
"""Accessor for the additional authenticated data."""
Expand Down

0 comments on commit 8022586

Please sign in to comment.