diff --git a/ddht/abc.py b/ddht/abc.py index 66216d2e..6722422f 100644 --- a/ddht/abc.py +++ b/ddht/abc.py @@ -22,7 +22,7 @@ from ddht.base_message import BaseMessage from ddht.boot_info import BootInfo -from ddht.typing import JSON, IDNonce, SessionKeys +from ddht.typing import JSON, SessionKeys TAddress = TypeVar("TAddress", bound="AddressAPI") @@ -177,8 +177,12 @@ def iter_all_random(self) -> Iterator[NodeID]: ... -class HandshakeSchemeAPI(ABC): +TSignatureInputs = TypeVar("TSignatureInputs") + + +class HandshakeSchemeAPI(ABC, Generic[TSignatureInputs]): identity_scheme: Type[IdentitySchemeAPI] + signature_inputs_cls: Type[TSignatureInputs] # # Handshake @@ -204,7 +208,7 @@ def compute_session_keys( remote_public_key: bytes, local_node_id: NodeID, remote_node_id: NodeID, - id_nonce: IDNonce, + salt: bytes, is_locally_initiated: bool, ) -> SessionKeys: """Compute the symmetric session keys.""" @@ -213,7 +217,7 @@ def compute_session_keys( @classmethod @abstractmethod def create_id_nonce_signature( - cls, *, id_nonce: IDNonce, ephemeral_public_key: bytes, private_key: bytes + cls, *, signature_inputs: TSignatureInputs, private_key: bytes, ) -> bytes: """Sign an id nonce received during handshake.""" ... @@ -221,12 +225,7 @@ def create_id_nonce_signature( @classmethod @abstractmethod def validate_id_nonce_signature( - cls, - *, - id_nonce: IDNonce, - ephemeral_public_key: bytes, - signature: bytes, - public_key: bytes, + cls, *, signature_inputs: TSignatureInputs, signature: bytes, public_key: bytes, ) -> None: """Validate the id nonce signature received from a peer.""" ... @@ -235,7 +234,7 @@ def validate_id_nonce_signature( # https://github.com/python/mypy/issues/5264#issuecomment-399407428 if TYPE_CHECKING: HandshakeSchemeRegistryBaseType = UserDict[ - Type[IdentitySchemeAPI], Type[HandshakeSchemeAPI] + Type[IdentitySchemeAPI], Type[HandshakeSchemeAPI[Any]] ] else: HandshakeSchemeRegistryBaseType = UserDict @@ -244,8 +243,8 @@ def validate_id_nonce_signature( class HandshakeSchemeRegistryAPI(HandshakeSchemeRegistryBaseType): @abstractmethod def register( - self, handshake_scheme_class: Type[HandshakeSchemeAPI] - ) -> Type[HandshakeSchemeAPI]: + self, handshake_scheme_class: Type[HandshakeSchemeAPI[TSignatureInputs]] + ) -> Type[HandshakeSchemeAPI[TSignatureInputs]]: ... diff --git a/ddht/constants.py b/ddht/constants.py index 87702590..16cf07b8 100644 --- a/ddht/constants.py +++ b/ddht/constants.py @@ -17,7 +17,6 @@ NEIGHBOURS_RESPONSE_ITEMS = 16 AES128_KEY_SIZE = 16 # size of an AES218 key HKDF_INFO = b"discovery v5 key agreement" -ID_NONCE_SIGNATURE_PREFIX = b"discovery-id-nonce" ENR_REPR_PREFIX = "enr:" # prefix used when printing an ENR MAX_ENR_SIZE = 300 # maximum allowed size of an ENR IP_V4_ADDRESS_ENR_KEY = b"ip" diff --git a/ddht/encryption.py b/ddht/encryption.py index 1ba6e035..02a44981 100644 --- a/ddht/encryption.py +++ b/ddht/encryption.py @@ -24,7 +24,7 @@ def validate_nonce(nonce: bytes) -> None: def aesgcm_encrypt( - key: AES128Key, nonce: Nonce, plain_text: bytes, authenticated_data: bytes + key: AES128Key, nonce: Nonce, plain_text: bytes, authenticated_data: bytes, ) -> bytes: validate_aes128_key(key) validate_nonce(nonce) @@ -35,7 +35,7 @@ def aesgcm_encrypt( def aesgcm_decrypt( - key: AES128Key, nonce: Nonce, cipher_text: bytes, authenticated_data: bytes + key: AES128Key, nonce: Nonce, cipher_text: bytes, authenticated_data: bytes, ) -> bytes: validate_aes128_key(key) validate_nonce(nonce) @@ -44,7 +44,7 @@ def aesgcm_decrypt( try: plain_text = aesgcm.decrypt(nonce, cipher_text, authenticated_data) except InvalidTag as error: - raise DecryptionError() from error + raise DecryptionError(str(error)) from error else: return plain_text diff --git a/ddht/handshake_schemes.py b/ddht/handshake_schemes.py index dc5ac046..ad01b013 100644 --- a/ddht/handshake_schemes.py +++ b/ddht/handshake_schemes.py @@ -1,4 +1,3 @@ -from hashlib import sha256 import secrets from typing import Tuple, Type @@ -13,15 +12,15 @@ from eth_typing import NodeID from eth_utils import ValidationError, encode_hex -from ddht.abc import HandshakeSchemeAPI, HandshakeSchemeRegistryAPI -from ddht.constants import AES128_KEY_SIZE, HKDF_INFO, ID_NONCE_SIGNATURE_PREFIX -from ddht.typing import AES128Key, IDNonce, SessionKeys +from ddht.abc import HandshakeSchemeAPI, HandshakeSchemeRegistryAPI, TSignatureInputs +from ddht.constants import AES128_KEY_SIZE, HKDF_INFO +from ddht.typing import AES128Key, SessionKeys class HandshakeSchemeRegistry(HandshakeSchemeRegistryAPI): def register( - self, handshake_scheme_class: Type[HandshakeSchemeAPI] - ) -> Type[HandshakeSchemeAPI]: + self, handshake_scheme_class: Type[HandshakeSchemeAPI[TSignatureInputs]] + ) -> Type[HandshakeSchemeAPI[TSignatureInputs]]: """Class decorator to register handshake schemes.""" is_missing_identity_scheme = ( not hasattr(handshake_scheme_class, "identity_scheme") @@ -42,9 +41,6 @@ def register( return handshake_scheme_class -default_handshake_scheme_registry = HandshakeSchemeRegistry() - - def ecdh_agree(private_key: bytes, public_key: bytes) -> bytes: """ Perform the ECDH key agreement. @@ -66,17 +62,14 @@ def ecdh_agree(private_key: bytes, public_key: bytes) -> bytes: def hkdf_expand_and_extract( - secret: bytes, - initiator_node_id: NodeID, - recipient_node_id: NodeID, - id_nonce: IDNonce, + secret: bytes, initiator_node_id: NodeID, recipient_node_id: NodeID, salt: bytes, ) -> Tuple[bytes, bytes, bytes]: info = b"".join((HKDF_INFO, initiator_node_id, recipient_node_id)) hkdf = HKDF( algorithm=SHA256(), length=3 * AES128_KEY_SIZE, - salt=id_nonce, + salt=salt, info=info, backend=cryptography_default_backend(), ) @@ -94,8 +87,7 @@ def hkdf_expand_and_extract( return initiator_key, recipient_key, auth_response_key -@default_handshake_scheme_registry.register -class V4HandshakeScheme(HandshakeSchemeAPI): +class BaseV4HandshakeScheme(HandshakeSchemeAPI[TSignatureInputs]): identity_scheme = V4IdentityScheme # @@ -119,7 +111,7 @@ def compute_session_keys( remote_public_key: bytes, local_node_id: NodeID, remote_node_id: NodeID, - id_nonce: IDNonce, + salt: bytes, is_locally_initiated: bool, ) -> SessionKeys: secret = ecdh_agree(local_private_key, remote_public_key) @@ -130,7 +122,7 @@ def compute_session_keys( initiator_node_id, recipient_node_id = remote_node_id, local_node_id initiator_key, recipient_key, auth_response_key = hkdf_expand_and_extract( - secret, initiator_node_id, recipient_node_id, id_nonce + secret, initiator_node_id, recipient_node_id, salt, ) if is_locally_initiated: @@ -144,33 +136,6 @@ def compute_session_keys( auth_response_key=AES128Key(auth_response_key), ) - @classmethod - def create_id_nonce_signature( - cls, *, id_nonce: IDNonce, ephemeral_public_key: bytes, private_key: bytes - ) -> bytes: - private_key_object = PrivateKey(private_key) - signature_input = cls.create_id_nonce_signature_input( - id_nonce=id_nonce, ephemeral_public_key=ephemeral_public_key - ) - signature = private_key_object.sign_msg_hash_non_recoverable(signature_input) - return bytes(signature) - - @classmethod - def validate_id_nonce_signature( - cls, - *, - id_nonce: IDNonce, - ephemeral_public_key: bytes, - signature: bytes, - public_key: bytes, - ) -> None: - signature_input = cls.create_id_nonce_signature_input( - id_nonce=id_nonce, ephemeral_public_key=ephemeral_public_key - ) - cls.identity_scheme.validate_signature( - message_hash=signature_input, signature=signature, public_key=public_key - ) - # # Helpers # @@ -210,10 +175,3 @@ def validate_signature( f"Signature {encode_hex(signature)} is not valid for message hash " f"{encode_hex(message_hash)} and public key {encode_hex(public_key)}" ) - - @classmethod - def create_id_nonce_signature_input( - cls, *, id_nonce: IDNonce, ephemeral_public_key: bytes - ) -> bytes: - preimage = b"".join((ID_NONCE_SIGNATURE_PREFIX, id_nonce, ephemeral_public_key)) - return sha256(preimage).digest() diff --git a/ddht/tools/driver/abc.py b/ddht/tools/driver/abc.py index 79c4a2a1..665079b5 100644 --- a/ddht/tools/driver/abc.py +++ b/ddht/tools/driver/abc.py @@ -68,11 +68,11 @@ async def next_message(self) -> AnyInboundMessage: ... @abstractmethod - async def send_ping(self, request_id: Optional[int] = None) -> PingMessage: + async def send_ping(self, request_id: Optional[bytes] = None) -> PingMessage: ... @abstractmethod - async def send_pong(self, request_id: Optional[int] = None) -> PongMessage: + async def send_pong(self, request_id: Optional[bytes] = None) -> PongMessage: ... diff --git a/ddht/tools/driver/session.py b/ddht/tools/driver/session.py index 36c5cff9..b78c3187 100644 --- a/ddht/tools/driver/session.py +++ b/ddht/tools/driver/session.py @@ -2,6 +2,7 @@ from typing import AsyncIterator, List, Optional from async_generator import asynccontextmanager +from eth_utils import int_to_big_endian import trio from ddht.base_message import AnyInboundMessage, AnyOutboundMessage, BaseMessage @@ -51,17 +52,17 @@ async def next_message(self) -> AnyInboundMessage: return await self.channels.inbound_message_receive_channel.receive() @no_hang - async def send_ping(self, request_id: Optional[int] = None) -> PingMessage: + async def send_ping(self, request_id: Optional[bytes] = None) -> PingMessage: if request_id is None: - request_id = secrets.randbits(32) + request_id = int_to_big_endian(secrets.randbits(32)) message = PingMessage(request_id, self.node.enr.sequence_number) await self.send_message(message) return message @no_hang - async def send_pong(self, request_id: Optional[int] = None) -> PongMessage: + async def send_pong(self, request_id: Optional[bytes] = None) -> PongMessage: if request_id is None: - request_id = secrets.randbits(32) + request_id = int_to_big_endian(secrets.randbits(32)) message = PongMessage( request_id, self.node.enr.sequence_number, @@ -133,22 +134,20 @@ async def transmit(self) -> AsyncIterator[None]: @no_hang async def send_packet(self, packet: AnyPacket) -> None: - if packet.header.source_node_id == self.initiator.node.node_id: - await self.recipient.session.handle_inbound_envelope( + if packet.dest_node_id == self.initiator.node.node_id: + await self.initiator.session.handle_inbound_envelope( InboundEnvelope( - packet=packet, sender_endpoint=self.initiator.node.endpoint, + packet=packet, sender_endpoint=self.recipient.node.endpoint, ) ) - elif packet.header.source_node_id == self.recipient.node.node_id: - await self.initiator.session.handle_inbound_envelope( + elif packet.dest_node_id == self.recipient.node.node_id: + await self.recipient.session.handle_inbound_envelope( InboundEnvelope( - packet=packet, sender_endpoint=self.recipient.node.endpoint, + packet=packet, sender_endpoint=self.initiator.node.endpoint, ) ) else: - raise Exception( - f"No matching node-id: {packet.header.source_node_id.hex()}" - ) + raise Exception(f"No matching node-id: {packet.dest_node_id.hex()}") @no_hang async def handshake(self) -> None: diff --git a/ddht/tools/factories/v5_1.py b/ddht/tools/factories/v5_1.py index 04252ae1..a1f62e4f 100644 --- a/ddht/tools/factories/v5_1.py +++ b/ddht/tools/factories/v5_1.py @@ -6,12 +6,14 @@ import trio from ddht.base_message import AnyInboundMessage, BaseMessage +from ddht.tools.factories.node_id import NodeIDFactory from ddht.typing import AES128Key, Nonce +from ddht.v5_1.constants import MESSAGE_PACKET_SIZE, PACKET_VERSION_1, PROTOCOL_ID from ddht.v5_1.envelope import OutboundEnvelope from ddht.v5_1.packets import ( - PROTOCOL_ID, HandshakeHeader, HandshakePacket, + Header, MessagePacket, Packet, TAuthData, @@ -24,7 +26,6 @@ class WhoAreYouPacketFactory(factory.Factory): # type: ignore class Meta: model = WhoAreYouPacket - request_nonce = factory.LazyFunction(lambda: secrets.token_bytes(12)) id_nonce = factory.LazyFunction(lambda: secrets.token_bytes(32)) enr_sequence_number = 0 @@ -33,7 +34,7 @@ class HandshakeHeaderFactory(factory.Factory): # type: ignore class Meta: model = HandshakeHeader - version = 1 + source_node_id = factory.SubFactory(NodeIDFactory) signature_size = 64 ephemeral_key_size = 33 @@ -48,36 +49,42 @@ class Meta: record = None +class HeaderFactory(factory.Factory): # type: ignore + protocol_id = PROTOCOL_ID + version = PACKET_VERSION_1 + flag = MessagePacket.flag + aes_gcm_nonce: Nonce = factory.LazyFunction(lambda: secrets.token_bytes(12)) + auth_data_size = MESSAGE_PACKET_SIZE + + class Meta: + model = Header + + class PacketFactory: @staticmethod def _prepare( *, - nonce: Optional[Nonce] = None, + aes_gcm_nonce: Optional[Nonce] = None, initiator_key: Optional[AES128Key] = None, message: BaseMessage, auth_data: TAuthData, - source_node_id: Optional[NodeID] = None, dest_node_id: Optional[NodeID] = None, protocol_id: bytes = PROTOCOL_ID ) -> Packet[TAuthData]: - if nonce is None: - nonce = Nonce(secrets.token_bytes(12)) + if aes_gcm_nonce is None: + aes_gcm_nonce = Nonce(secrets.token_bytes(12)) if initiator_key is None: initiator_key = AES128Key(secrets.token_bytes(16)) - if source_node_id is None: - source_node_id = NodeID(secrets.token_bytes(32)) - if dest_node_id is None: dest_node_id = NodeID(secrets.token_bytes(32)) return Packet.prepare( - nonce=nonce, + aes_gcm_nonce=aes_gcm_nonce, initiator_key=initiator_key, message=message, auth_data=auth_data, - source_node_id=source_node_id, dest_node_id=dest_node_id, protocol_id=protocol_id, ) @@ -86,27 +93,29 @@ def _prepare( def message( cls, *, - nonce: Optional[Nonce] = None, + aes_gcm_nonce: Optional[Nonce] = None, initiator_key: Optional[AES128Key] = None, message: Optional[BaseMessage] = None, source_node_id: Optional[NodeID] = None, dest_node_id: Optional[NodeID] = None, protocol_id: bytes = PROTOCOL_ID ) -> Packet[MessagePacket]: - if nonce is None: - nonce = Nonce(secrets.token_bytes(12)) + if aes_gcm_nonce is None: + aes_gcm_nonce = Nonce(secrets.token_bytes(12)) - auth_data = MessagePacket(nonce) + if source_node_id is None: + source_node_id = NodeIDFactory() + + auth_data = MessagePacket(source_node_id) if message is None: message = RandomMessage() return cls._prepare( - nonce=nonce, + aes_gcm_nonce=aes_gcm_nonce, initiator_key=initiator_key, message=message, auth_data=auth_data, - source_node_id=source_node_id, dest_node_id=dest_node_id, protocol_id=protocol_id, ) @@ -115,10 +124,9 @@ def message( def who_are_you( cls, *, - nonce: Optional[Nonce] = None, + aes_gcm_nonce: Optional[Nonce] = None, initiator_key: Optional[AES128Key] = None, message: Optional[BaseMessage] = None, - source_node_id: Optional[NodeID] = None, dest_node_id: Optional[NodeID] = None, protocol_id: bytes = PROTOCOL_ID ) -> Packet[MessagePacket]: @@ -126,11 +134,10 @@ def who_are_you( message = EmptyMessage() return cls._prepare( - nonce=nonce, + aes_gcm_nonce=aes_gcm_nonce, initiator_key=initiator_key, message=message, auth_data=auth_data, - source_node_id=source_node_id, dest_node_id=dest_node_id, protocol_id=protocol_id, ) @@ -139,24 +146,28 @@ def who_are_you( def handshake( cls, *, - nonce: Optional[Nonce] = None, + aes_gcm_nonce: Optional[Nonce] = None, initiator_key: Optional[AES128Key] = None, message: Optional[BaseMessage] = None, source_node_id: Optional[NodeID] = None, dest_node_id: Optional[NodeID] = None, protocol_id: bytes = PROTOCOL_ID ) -> Packet[MessagePacket]: - auth_data = HandshakePacketFactory() + if source_node_id is None: + source_node_id = NodeIDFactory() + + auth_data = HandshakePacketFactory( + auth_data_head__source_node_id=source_node_id + ) if message is None: message = EmptyMessage() return cls._prepare( - nonce=nonce, + aes_gcm_nonce=aes_gcm_nonce, initiator_key=initiator_key, message=message, auth_data=auth_data, - source_node_id=source_node_id, dest_node_id=dest_node_id, protocol_id=protocol_id, ) diff --git a/ddht/tools/v5_strategies.py b/ddht/tools/v5_strategies.py index a563b3b2..61c152c6 100644 --- a/ddht/tools/v5_strategies.py +++ b/ddht/tools/v5_strategies.py @@ -13,6 +13,7 @@ public_key_st = st.binary(min_size=32, max_size=32) node_id_st = st.binary(min_size=32, max_size=32) magic_st = st.binary(min_size=MAGIC_SIZE, max_size=MAGIC_SIZE) +iv_st = st.binary(min_size=16, max_size=16) id_nonce_st = st.binary(min_size=ID_NONCE_SIZE, max_size=ID_NONCE_SIZE) enr_seq_st = st.integers(min_value=0) diff --git a/ddht/v5/abc.py b/ddht/v5/abc.py index 1daf41ec..f2488d7a 100644 --- a/ddht/v5/abc.py +++ b/ddht/v5/abc.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod from typing import ( + Any, AsyncContextManager, AsyncIterable, Generic, @@ -60,7 +61,7 @@ def identity_scheme(self) -> Type[IdentitySchemeAPI]: @property @abstractmethod - def handshake_scheme(self) -> Type[HandshakeSchemeAPI]: + def handshake_scheme(self) -> Type[HandshakeSchemeAPI[Any]]: """ The handshake scheme used during the handshake. diff --git a/ddht/v5/constants.py b/ddht/v5/constants.py index 4add93cd..14e804c0 100644 --- a/ddht/v5/constants.py +++ b/ddht/v5/constants.py @@ -50,3 +50,5 @@ LOOKUP_PARALLELIZATION_FACTOR = 3 # number of parallel lookup requests (aka alpha) MAX_NODES_MESSAGE_TOTAL = 8 # max allowed total value for nodes messages + +ID_NONCE_SIGNATURE_PREFIX = b"discovery-id-nonce" diff --git a/ddht/v5/handshake.py b/ddht/v5/handshake.py index 3d84714e..eda0916a 100644 --- a/ddht/v5/handshake.py +++ b/ddht/v5/handshake.py @@ -1,5 +1,5 @@ import secrets -from typing import Optional, Type +from typing import Any, Optional, Type from eth_enr.abc import ENRAPI, IdentitySchemeAPI from eth_keys.datatypes import PublicKey @@ -8,9 +8,9 @@ from ddht.abc import HandshakeSchemeAPI, HandshakeSchemeRegistryAPI from ddht.exceptions import DecryptionError, HandshakeFailure -from ddht.handshake_schemes import default_handshake_scheme_registry from ddht.typing import AES128Key, IDNonce, Nonce from ddht.v5.abc import HandshakeParticipantAPI +from ddht.v5.handshake_schemes import v5_handshake_scheme_registry from ddht.v5.messages import BaseMessage from ddht.v5.packets import ( AuthHeaderPacket, @@ -26,7 +26,7 @@ class BaseHandshakeParticipant(HandshakeParticipantAPI): - _handshake_scheme_registry: HandshakeSchemeRegistryAPI = default_handshake_scheme_registry + _handshake_scheme_registry: HandshakeSchemeRegistryAPI = v5_handshake_scheme_registry def __init__( self, @@ -68,7 +68,7 @@ def tag(self) -> Tag: ) @property - def handshake_scheme(self) -> Type[HandshakeSchemeAPI]: + def handshake_scheme(self) -> Type[HandshakeSchemeAPI[Any]]: return self._handshake_scheme_registry[self.identity_scheme] @@ -133,15 +133,17 @@ def complete_handshake(self, response_packet: Packet) -> HandshakeResult: remote_public_key=remote_public_key_uncompressed, local_node_id=self.local_enr.node_id, remote_node_id=self.remote_node_id, - id_nonce=who_are_you_packet.id_nonce, + salt=who_are_you_packet.id_nonce, is_locally_initiated=True, ) # prepare response packet - id_nonce_signature = self.handshake_scheme.create_id_nonce_signature( + signature_inputs = self.handshake_scheme.signature_inputs_cls( id_nonce=who_are_you_packet.id_nonce, ephemeral_public_key=ephemeral_public_key, - private_key=self.local_private_key, + ) + id_nonce_signature = self.handshake_scheme.create_id_nonce_signature( + signature_inputs=signature_inputs, private_key=self.local_private_key, ) enr: Optional[ENRAPI] @@ -249,7 +251,7 @@ def complete_handshake(self, response_packet: Packet) -> HandshakeResult: remote_public_key=ephemeral_public_key, local_node_id=self.local_enr.node_id, remote_node_id=self.remote_node_id, - id_nonce=self.who_are_you_packet.id_nonce, + salt=self.who_are_you_packet.id_nonce, is_locally_initiated=False, ) @@ -309,11 +311,14 @@ def decrypt_and_validate_auth_response( current_remote_enr = enr + signature_inputs = self.handshake_scheme.signature_inputs_cls( + id_nonce=id_nonce, + ephemeral_public_key=auth_header_packet.auth_header.ephemeral_public_key, + ) try: self.handshake_scheme.validate_id_nonce_signature( + signature_inputs=signature_inputs, signature=id_nonce_signature, - id_nonce=id_nonce, - ephemeral_public_key=auth_header_packet.auth_header.ephemeral_public_key, public_key=current_remote_enr.public_key, ) except ValidationError as error: diff --git a/ddht/v5/handshake_schemes.py b/ddht/v5/handshake_schemes.py new file mode 100644 index 00000000..7036c2d3 --- /dev/null +++ b/ddht/v5/handshake_schemes.py @@ -0,0 +1,49 @@ +from hashlib import sha256 +from typing import NamedTuple + +from eth_keys.datatypes import PrivateKey + +from ddht.handshake_schemes import BaseV4HandshakeScheme, HandshakeSchemeRegistry +from ddht.typing import IDNonce +from ddht.v5.constants import ID_NONCE_SIGNATURE_PREFIX + +v5_handshake_scheme_registry = HandshakeSchemeRegistry() + + +class SignatureInputs(NamedTuple): + id_nonce: IDNonce + ephemeral_public_key: bytes + + +@v5_handshake_scheme_registry.register +class V4HandshakeScheme(BaseV4HandshakeScheme[SignatureInputs]): + signature_inputs_cls = SignatureInputs + + @classmethod + def create_id_nonce_signature( + cls, *, signature_inputs: SignatureInputs, private_key: bytes, + ) -> bytes: + private_key_object = PrivateKey(private_key) + signature_input = cls.create_id_nonce_signature_input( + signature_inputs=signature_inputs + ) + signature = private_key_object.sign_msg_hash_non_recoverable(signature_input) + return bytes(signature) + + @classmethod + def validate_id_nonce_signature( + cls, *, signature_inputs: SignatureInputs, signature: bytes, public_key: bytes, + ) -> None: + signature_input = cls.create_id_nonce_signature_input( + signature_inputs=signature_inputs + ) + cls.identity_scheme.validate_signature( + message_hash=signature_input, signature=signature, public_key=public_key + ) + + @classmethod + def create_id_nonce_signature_input( + cls, *, signature_inputs: SignatureInputs, + ) -> bytes: + preimage = b"".join((ID_NONCE_SIGNATURE_PREFIX,) + signature_inputs) + return sha256(preimage).digest() diff --git a/ddht/v5_1/abc.py b/ddht/v5_1/abc.py index 9d113ba6..92d82e85 100644 --- a/ddht/v5_1/abc.py +++ b/ddht/v5_1/abc.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod import logging from typing import ( + Any, AsyncContextManager, Collection, ContextManager, @@ -69,7 +70,7 @@ def identity_scheme(self) -> Type[IdentitySchemeAPI]: @property @abstractmethod - def handshake_scheme(self) -> Type[HandshakeSchemeAPI]: + def handshake_scheme(self) -> Type[HandshakeSchemeAPI[Any]]: ... @property @@ -204,11 +205,11 @@ async def send_message(self, message: AnyOutboundMessage) -> None: ... @abstractmethod - def get_free_request_id(self, node_id: NodeID) -> int: + def get_free_request_id(self, node_id: NodeID) -> bytes: ... @abstractmethod - def reserve_request_id(self, node_id: NodeID) -> ContextManager[int]: + def reserve_request_id(self, node_id: NodeID) -> ContextManager[bytes]: ... @abstractmethod @@ -247,8 +248,12 @@ async def wait_listening(self) -> None: @abstractmethod async def send_ping( - self, endpoint: Endpoint, node_id: NodeID, *, request_id: Optional[int] = None, - ) -> int: + self, + endpoint: Endpoint, + node_id: NodeID, + *, + request_id: Optional[bytes] = None, + ) -> bytes: ... # @@ -256,7 +261,7 @@ async def send_ping( # @abstractmethod async def send_pong( - self, endpoint: Endpoint, node_id: NodeID, *, request_id: int, + self, endpoint: Endpoint, node_id: NodeID, *, request_id: bytes, ) -> None: ... @@ -267,8 +272,8 @@ async def send_find_nodes( node_id: NodeID, *, distances: Collection[int], - request_id: Optional[int] = None, - ) -> int: + request_id: Optional[bytes] = None, + ) -> bytes: ... @abstractmethod @@ -278,7 +283,7 @@ async def send_found_nodes( node_id: NodeID, *, enrs: Sequence[ENRAPI], - request_id: int, + request_id: bytes, ) -> int: ... @@ -290,13 +295,13 @@ async def send_talk_request( *, protocol: bytes, payload: bytes, - request_id: Optional[int] = None, - ) -> int: + request_id: Optional[bytes] = None, + ) -> bytes: ... @abstractmethod async def send_talk_response( - self, endpoint: Endpoint, node_id: NodeID, *, payload: bytes, request_id: int, + self, endpoint: Endpoint, node_id: NodeID, *, payload: bytes, request_id: bytes, ) -> None: ... @@ -309,8 +314,8 @@ async def send_register_topic( topic: bytes, enr: ENRAPI, ticket: bytes = b"", - request_id: Optional[int] = None, - ) -> int: + request_id: Optional[bytes] = None, + ) -> bytes: ... @abstractmethod @@ -321,13 +326,13 @@ async def send_ticket( *, ticket: bytes, wait_time: int, - request_id: int, + request_id: bytes, ) -> None: ... @abstractmethod async def send_registration_confirmation( - self, endpoint: Endpoint, node_id: NodeID, *, topic: bytes, request_id: int, + self, endpoint: Endpoint, node_id: NodeID, *, topic: bytes, request_id: bytes, ) -> None: ... @@ -338,8 +343,8 @@ async def send_topic_query( node_id: NodeID, *, topic: bytes, - request_id: Optional[int] = None, - ) -> int: + request_id: Optional[bytes] = None, + ) -> bytes: ... # diff --git a/ddht/v5_1/app.py b/ddht/v5_1/app.py index edda21d8..043605e7 100644 --- a/ddht/v5_1/app.py +++ b/ddht/v5_1/app.py @@ -97,7 +97,7 @@ async def run(self) -> None: if self._boot_info.is_rpc_enabled: handlers = merge( - get_core_rpc_handlers(network.routing_table), + get_core_rpc_handlers(enr_manager.enr, network.routing_table), get_v51_rpc_handlers(network), ) rpc_server = RPCServer(self._boot_info.ipc_path, handlers) @@ -106,11 +106,12 @@ async def run(self) -> None: self.logger.info("Protocol-Version: %s", self._boot_info.protocol_version.value) self.logger.info("DDHT base dir: %s", self._boot_info.base_dir) self.logger.info("Starting discovery service...") - self.logger.info("Listening on %s:%d", listen_on, port) + self.logger.info("Listening on %s", listen_on) self.logger.info("Local Node ID: %s", encode_hex(enr_manager.enr.node_id)) self.logger.info( "Local ENR: seq=%d enr=%s", enr_manager.enr.sequence_number, enr_manager.enr ) self.manager.run_daemon_child_service(network) - self.manager.run_daemon_child_service(rpc_server) + + await self.manager.wait_finished() diff --git a/ddht/v5_1/client.py b/ddht/v5_1/client.py index 16f59b45..d3a2c3af 100644 --- a/ddht/v5_1/client.py +++ b/ddht/v5_1/client.py @@ -174,9 +174,9 @@ async def _do_listen(self, listen_on: Endpoint) -> None: # @contextmanager def _get_request_id( - self, node_id: NodeID, request_id: Optional[int] = None - ) -> Iterator[int]: - request_id_context: ContextManager[int] + self, node_id: NodeID, request_id: Optional[bytes] = None + ) -> Iterator[bytes]: + request_id_context: ContextManager[bytes] if request_id is None: request_id_context = self.dispatcher.reserve_request_id(node_id) @@ -187,8 +187,12 @@ def _get_request_id( yield message_request_id async def send_ping( - self, endpoint: Endpoint, node_id: NodeID, *, request_id: Optional[int] = None, - ) -> int: + self, + endpoint: Endpoint, + node_id: NodeID, + *, + request_id: Optional[bytes] = None, + ) -> bytes: with self._get_request_id(node_id, request_id) as message_request_id: message = AnyOutboundMessage( PingMessage(message_request_id, self.enr_manager.enr.sequence_number), @@ -200,7 +204,7 @@ async def send_ping( return message_request_id async def send_pong( - self, endpoint: Endpoint, node_id: NodeID, *, request_id: int, + self, endpoint: Endpoint, node_id: NodeID, *, request_id: bytes, ) -> None: message = AnyOutboundMessage( PongMessage( @@ -220,8 +224,8 @@ async def send_find_nodes( node_id: NodeID, *, distances: Collection[int], - request_id: Optional[int] = None, - ) -> int: + request_id: Optional[bytes] = None, + ) -> bytes: with self._get_request_id(node_id, request_id) as message_request_id: message = AnyOutboundMessage( FindNodeMessage(message_request_id, tuple(distances)), @@ -238,7 +242,7 @@ async def send_found_nodes( node_id: NodeID, *, enrs: Sequence[ENRAPI], - request_id: int, + request_id: bytes, ) -> int: enr_batches = partition_enrs( enrs, max_payload_size=FOUND_NODES_MAX_PAYLOAD_SIZE @@ -259,8 +263,8 @@ async def send_talk_request( *, protocol: bytes, payload: bytes, - request_id: Optional[int] = None, - ) -> int: + request_id: Optional[bytes] = None, + ) -> bytes: with self._get_request_id(node_id, request_id) as message_request_id: message = AnyOutboundMessage( TalkRequestMessage(message_request_id, protocol, payload), @@ -272,7 +276,7 @@ async def send_talk_request( return message_request_id async def send_talk_response( - self, endpoint: Endpoint, node_id: NodeID, *, payload: bytes, request_id: int, + self, endpoint: Endpoint, node_id: NodeID, *, payload: bytes, request_id: bytes, ) -> None: message = AnyOutboundMessage( TalkResponseMessage(request_id, payload), endpoint, node_id, @@ -287,8 +291,8 @@ async def send_register_topic( topic: bytes, enr: ENRAPI, ticket: bytes = b"", - request_id: Optional[int] = None, - ) -> int: + request_id: Optional[bytes] = None, + ) -> bytes: with self._get_request_id(node_id, request_id) as message_request_id: message = AnyOutboundMessage( RegisterTopicMessage(message_request_id, topic, enr, ticket), @@ -306,7 +310,7 @@ async def send_ticket( *, ticket: bytes, wait_time: int, - request_id: int, + request_id: bytes, ) -> None: message = AnyOutboundMessage( TicketMessage(request_id, ticket, wait_time), endpoint, node_id, @@ -314,7 +318,7 @@ async def send_ticket( await self.dispatcher.send_message(message) async def send_registration_confirmation( - self, endpoint: Endpoint, node_id: NodeID, *, topic: bytes, request_id: int, + self, endpoint: Endpoint, node_id: NodeID, *, topic: bytes, request_id: bytes, ) -> None: message = AnyOutboundMessage( RegistrationConfirmationMessage(request_id, topic), endpoint, node_id, @@ -327,8 +331,8 @@ async def send_topic_query( node_id: NodeID, *, topic: bytes, - request_id: Optional[int] = None, - ) -> int: + request_id: Optional[bytes] = None, + ) -> bytes: with self._get_request_id(node_id, request_id) as message_request_id: message = AnyOutboundMessage( TopicQueryMessage(message_request_id, topic), endpoint, node_id, @@ -343,6 +347,7 @@ async def ping( self, endpoint: Endpoint, node_id: NodeID ) -> InboundMessage[PongMessage]: with self._get_request_id(node_id) as request_id: + assert isinstance(request_id, bytes) request = AnyOutboundMessage( PingMessage(request_id, self.enr_manager.enr.sequence_number), endpoint, diff --git a/ddht/v5_1/constants.py b/ddht/v5_1/constants.py index c3543a20..d685dda6 100644 --- a/ddht/v5_1/constants.py +++ b/ddht/v5_1/constants.py @@ -13,5 +13,20 @@ DEFAULT_BOOTNODES: Tuple[str, ...] = ( - "enr:-IS4QAHCaxRnea2ypKUsy-Ldotp6pYtpqUfq0DfGBEeBhuiKDZg2lm12Bjt6KPUAccmPyFA5zkpT-ciVjt6zcxjue8cDgmlkgnY0gmlwhK3mkMaJc2VjcDI1NmsxoQJTqyNYyHmn7ibmkepLZVEuznjQeTGyAyH-xQLyL-6dZ4N1ZHCCdl8", # noqa: E501 + "enr:-IS4QKcAHi77_OQBuGolVX-I1dmQxyuZAsSTh3Z7Jck3LrzbYQ2NXzMEKvpit0cyH2coB55ddVDvKA8p5IUcg7DLQj4DgmlkgnY0gmlwhC1PW26Jc2VjcDI1NmsxoQPNz0D8sSVKyNTZuGRTTnPabutpJ8IUxpAyMqrVosZ14IN1ZHCCdl8", # noqa: E501 ) + + +PACKET_VERSION_1 = b"\x00\x01" + +ID_NONCE_SIGNATURE_PREFIX = b"discovery v5 identity proof" + +HEADER_PACKET_SIZE = 23 + +PROTOCOL_ID = b"discv5" + +WHO_ARE_YOU_PACKET_SIZE = 24 + +HANDSHAKE_HEADER_PACKET_SIZE = 34 + +MESSAGE_PACKET_SIZE = 32 diff --git a/ddht/v5_1/dispatcher.py b/ddht/v5_1/dispatcher.py index 1f0cd951..37725743 100644 --- a/ddht/v5_1/dispatcher.py +++ b/ddht/v5_1/dispatcher.py @@ -18,6 +18,7 @@ from async_service import Service from eth_enr import ENRDatabaseAPI from eth_typing import NodeID +from eth_utils import int_to_big_endian import trio from ddht._utils import humanize_node_id @@ -57,8 +58,8 @@ class _Subcription(NamedTuple): filter_by_node_id: Optional[NodeID] -def get_random_request_id() -> int: - return secrets.randbits(32) +def get_random_request_id() -> bytes: + return int_to_big_endian(secrets.randbits(32)) MAX_REQUEST_ID_ATTEMPTS = 3 @@ -127,8 +128,8 @@ class Dispatcher(Service, DispatcherAPI): _subscriptions: DefaultDict[int, Set[_Subcription]] - _reserved_request_ids: Set[Tuple[NodeID, int]] - _active_request_ids: Set[Tuple[NodeID, int]] + _reserved_request_ids: Set[Tuple[NodeID, bytes]] + _active_request_ids: Set[Tuple[NodeID, bytes]] def __init__( self, @@ -341,7 +342,18 @@ def _get_sessions_for_inbound_envelope( ) if ( not session.is_after_handshake - or session.remote_node_id == envelope.packet.header.source_node_id + or ( + ( + envelope.packet.is_message + and session.remote_node_id + == envelope.packet.auth_data.source_node_id # type: ignore # noqa: E501 + ) + or ( + envelope.packet.is_handshake + and session.remote_node_id + == envelope.packet.auth_data.auth_data_head.source_node_id # type: ignore + ) + ) ) ) @@ -385,9 +397,10 @@ def _get_sessions_for_outbound_message( # # Utility # - def get_free_request_id(self, node_id: NodeID) -> int: + def get_free_request_id(self, node_id: NodeID) -> bytes: for _ in range(MAX_REQUEST_ID_ATTEMPTS): request_id = get_random_request_id() + if (node_id, request_id) in self._reserved_request_ids: continue elif (node_id, request_id) in self._active_request_ids: @@ -404,7 +417,7 @@ def get_free_request_id(self, node_id: NodeID) -> int: ) @contextmanager - def reserve_request_id(self, node_id: NodeID) -> Iterator[int]: + def reserve_request_id(self, node_id: NodeID) -> Iterator[bytes]: request_id = self.get_free_request_id(node_id) try: self._reserved_request_ids.add((node_id, request_id)) @@ -450,7 +463,7 @@ async def subscribe_request( request_id = request.message.request_id self.logger.debug( - "Sending request: %s with request id %d", request, request_id, + "Sending request: %s with request id %s", request, request_id.hex(), ) send_channel, receive_channel = trio.open_memory_channel[TMessage](256) @@ -489,7 +502,7 @@ async def _manage_request_response( ) async with subscription_ctx as subscription: self.logger.debug( - "Sending request with request id %d", request_id, + "Sending request with request id %s", request_id.hex(), ) # Send the request await self.send_message(request) diff --git a/ddht/v5_1/handshake_schemes.py b/ddht/v5_1/handshake_schemes.py new file mode 100644 index 00000000..e453fb8b --- /dev/null +++ b/ddht/v5_1/handshake_schemes.py @@ -0,0 +1,62 @@ +from hashlib import sha256 +from typing import NamedTuple + +from eth_keys.datatypes import PrivateKey +from eth_typing import NodeID + +from ddht.handshake_schemes import BaseV4HandshakeScheme, HandshakeSchemeRegistry +from ddht.v5_1.constants import ID_NONCE_SIGNATURE_PREFIX +from ddht.v5_1.packets import Header, WhoAreYouPacket + +v51_handshake_scheme_registry = HandshakeSchemeRegistry() + + +class SignatureInputs(NamedTuple): + iv: bytes + header: Header + who_are_you: WhoAreYouPacket + ephemeral_public_key: bytes + recipient_node_id: NodeID + + +@v51_handshake_scheme_registry.register +class V4HandshakeScheme(BaseV4HandshakeScheme[SignatureInputs]): + signature_inputs_cls = SignatureInputs + + @classmethod + def create_id_nonce_signature( + cls, *, signature_inputs: SignatureInputs, private_key: bytes, + ) -> bytes: + private_key_object = PrivateKey(private_key) + signature_input = cls.create_id_nonce_signature_input( + signature_inputs=signature_inputs + ) + signature = private_key_object.sign_msg_hash_non_recoverable(signature_input) + return bytes(signature) + + @classmethod + def validate_id_nonce_signature( + cls, *, signature_inputs: SignatureInputs, signature: bytes, public_key: bytes, + ) -> None: + signature_input = cls.create_id_nonce_signature_input( + signature_inputs=signature_inputs + ) + cls.identity_scheme.validate_signature( + message_hash=signature_input, signature=signature, public_key=public_key + ) + + @classmethod + def create_id_nonce_signature_input( + cls, *, signature_inputs: SignatureInputs, + ) -> bytes: + preimage = b"".join( + ( + ID_NONCE_SIGNATURE_PREFIX, + signature_inputs.iv, + signature_inputs.header.to_wire_bytes(), + signature_inputs.who_are_you.to_wire_bytes(), + signature_inputs.ephemeral_public_key, + signature_inputs.recipient_node_id, + ) + ) + return sha256(preimage).digest() diff --git a/ddht/v5_1/messages.py b/ddht/v5_1/messages.py index b3f76d74..14beca80 100644 --- a/ddht/v5_1/messages.py +++ b/ddht/v5_1/messages.py @@ -1,12 +1,16 @@ +from typing import cast + from eth_enr.sedes import ENRSedes +import rlp from rlp.sedes import Binary, CountableList, big_endian_int, binary from ddht.base_message import BaseMessage +from ddht.encryption import aesgcm_decrypt from ddht.message_registry import MessageTypeRegistry from ddht.sedes import ip_address_sedes -from ddht.v5.constants import TOPIC_HASH_SIZE +from ddht.typing import AES128Key, Nonce -topic_sedes = Binary.fixed_length(TOPIC_HASH_SIZE) +topic_sedes = Binary.fixed_length(32) v51_registry = MessageTypeRegistry() @@ -19,7 +23,7 @@ class PingMessage(BaseMessage): message_type = 1 - fields = (("request_id", big_endian_int), ("enr_seq", big_endian_int)) + fields = (("request_id", binary), ("enr_seq", big_endian_int)) @v51_registry.register @@ -27,7 +31,7 @@ class PongMessage(BaseMessage): message_type = 2 fields = ( - ("request_id", big_endian_int), + ("request_id", binary), ("enr_seq", big_endian_int), ("packet_ip", ip_address_sedes), ("packet_port", big_endian_int), @@ -39,7 +43,7 @@ class FindNodeMessage(BaseMessage): message_type = 3 fields = ( - ("request_id", big_endian_int), + ("request_id", binary), ("distances", CountableList(big_endian_int)), ) @@ -49,7 +53,7 @@ class FoundNodesMessage(BaseMessage): message_type = 4 fields = ( - ("request_id", big_endian_int), + ("request_id", binary), ("total", big_endian_int), ("enrs", CountableList(ENRSedes)), ) @@ -59,7 +63,10 @@ class FoundNodesMessage(BaseMessage): class TalkRequestMessage(BaseMessage): message_type = 5 - fields = (("request_id", big_endian_int), ("protocol", binary), ("payload", binary)) + protocol: bytes + payload: bytes + + fields = (("request_id", binary), ("protocol", binary), ("payload", binary)) @v51_registry.register @@ -68,7 +75,7 @@ class TalkResponseMessage(BaseMessage): payload: bytes - fields = (("request_id", big_endian_int), ("payload", binary)) + fields = (("request_id", binary), ("payload", binary)) @v51_registry.register @@ -76,7 +83,7 @@ class RegisterTopicMessage(BaseMessage): message_type = 7 fields = ( - ("request_id", big_endian_int), + ("request_id", binary), ("topic", topic_sedes), ("enr", ENRSedes), ("ticket", binary), @@ -88,7 +95,7 @@ class TicketMessage(BaseMessage): message_type = 8 fields = ( - ("request_id", big_endian_int), + ("request_id", binary), ("ticket", binary), ("wait_time", big_endian_int), ) @@ -98,11 +105,31 @@ class TicketMessage(BaseMessage): class RegistrationConfirmationMessage(BaseMessage): message_type = 9 - fields = (("request_id", big_endian_int), ("topic", binary)) + fields = (("request_id", binary), ("topic", binary)) @v51_registry.register class TopicQueryMessage(BaseMessage): message_type = 10 - fields = (("request_id", big_endian_int), ("topic", topic_sedes)) + fields = (("request_id", binary), ("topic", topic_sedes)) + + +def decode_message( + decryption_key: AES128Key, + aes_gcm_nonce: Nonce, + message_cipher_text: bytes, + authenticated_data: bytes, + message_type_registry: MessageTypeRegistry = v51_registry, +) -> BaseMessage: + message_plain_text = aesgcm_decrypt( + key=decryption_key, + nonce=aes_gcm_nonce, + cipher_text=message_cipher_text, + authenticated_data=authenticated_data, + ) + message_type = message_plain_text[0] + message_sedes = message_type_registry[message_type] + message = rlp.decode(message_plain_text[1:], sedes=message_sedes) + + return cast(BaseMessage, message) diff --git a/ddht/v5_1/packets.py b/ddht/v5_1/packets.py index fe3d252f..15b34c6a 100644 --- a/ddht/v5_1/packets.py +++ b/ddht/v5_1/packets.py @@ -14,89 +14,83 @@ from ddht.encryption import aesctr_decrypt_stream, aesctr_encrypt, aesgcm_encrypt from ddht.exceptions import DecodingError from ddht.typing import AES128Key, IDNonce, Nonce - -PROTOCOL_ID = b"discv5 " - +from ddht.v5_1.constants import ( + HANDSHAKE_HEADER_PACKET_SIZE, + HEADER_PACKET_SIZE, + MESSAGE_PACKET_SIZE, + PACKET_VERSION_1, + PROTOCOL_ID, + WHO_ARE_YOU_PACKET_SIZE, +) UINT8_TO_BYTES = {v: bytes([v]) for v in range(256)} @dataclass(frozen=True) class MessagePacket: - aes_gcm_nonce: Nonce # 96 bit AES/GCM nonce + source_node_id: NodeID flag: int = field(init=False, repr=False, default=0) def to_wire_bytes(self) -> bytes: - return self.aes_gcm_nonce + return self.source_node_id @classmethod def from_wire_bytes(cls, data: bytes) -> "MessagePacket": - if len(data) != 12: + if len(data) != MESSAGE_PACKET_SIZE: raise DecodingError( f"Invalid length for MessagePacket: length={len(data)} data={data.hex()}" ) - return cls(cast(Nonce, data)) + return cls(NodeID(data)) @dataclass(frozen=True) class WhoAreYouPacket: - request_nonce: Nonce # uint96 - id_nonce: IDNonce # uint256 + id_nonce: IDNonce # uint128 enr_sequence_number: int # uint64 flag: int = field(init=False, repr=False, default=1) def to_wire_bytes(self) -> bytes: - return b"".join( - ( - self.request_nonce, - self.id_nonce, - struct.pack(">Q", self.enr_sequence_number), - ) - ) + return b"".join((self.id_nonce, struct.pack(">Q", self.enr_sequence_number),)) @classmethod def from_wire_bytes(cls, data: bytes) -> "WhoAreYouPacket": - if len(data) != 52: + if len(data) != WHO_ARE_YOU_PACKET_SIZE: raise DecodingError( f"Invalid length for WhoAreYouPacket: length={len(data)} data={data.hex()}" ) stream = BytesIO(data) - request_nonce = cast(Nonce, stream.read(12)) - id_nonce = cast(IDNonce, stream.read(32)) + id_nonce = cast(IDNonce, stream.read(16)) enr_sequence_number = int.from_bytes(stream.read(8), "big") - return cls(request_nonce, id_nonce, enr_sequence_number) + return cls(id_nonce, enr_sequence_number) class HandshakeHeader(NamedTuple): - version: int # uint8 (1 for v4) + source_node_id: NodeID # bytes32 signature_size: int # uint8 (64 for v4) ephemeral_key_size: int # uint8 (33 for v4) def to_wire_bytes(self) -> bytes: return b"".join( ( - UINT8_TO_BYTES[self.version], + self.source_node_id, UINT8_TO_BYTES[self.signature_size], UINT8_TO_BYTES[self.ephemeral_key_size], ) ) - @classmethod - def v4_header(cls) -> "HandshakeHeader": - return cls(1, 64, 33) - @classmethod def from_wire_bytes(cls, data: bytes) -> "HandshakeHeader": - if len(data) != 3: + if len(data) != HANDSHAKE_HEADER_PACKET_SIZE: raise DecodingError( f"Invalid length for HandshakeHeader: length={len(data)} data={data.hex()}" ) - version = data[0] - signature_size = data[1] - ephemeral_key_size = data[2] - return cls(version, signature_size, ephemeral_key_size) + stream = BytesIO(data) + source_node_id = NodeID(stream.read(32)) + signature_size = stream.read(1)[0] + ephemeral_key_size = stream.read(1)[0] + return cls(source_node_id, signature_size, ephemeral_key_size) @dataclass(frozen=True) @@ -121,9 +115,13 @@ def to_wire_bytes(self) -> bytes: @classmethod def from_wire_bytes(cls, data: bytes) -> "HandshakePacket": stream = BytesIO(data) - auth_data_head = HandshakeHeader.from_wire_bytes(stream.read(3)) + auth_data_head = HandshakeHeader.from_wire_bytes( + stream.read(HANDSHAKE_HEADER_PACKET_SIZE) + ) expected_length = ( - 3 + auth_data_head.signature_size + auth_data_head.ephemeral_key_size + HANDSHAKE_HEADER_PACKET_SIZE + + auth_data_head.signature_size + + auth_data_head.ephemeral_key_size ) if len(data) < expected_length: raise DecodingError( @@ -148,32 +146,41 @@ def from_wire_bytes(cls, data: bytes) -> "HandshakePacket": class Header(NamedTuple): - protocol_id: bytes - source_node_id: NodeID + protocol_id: bytes # bytes6 + version: bytes # bytes2 flag: int # uint8 + aes_gcm_nonce: Nonce # uint96 auth_data_size: int # uint16 def to_wire_bytes(self) -> bytes: return b"".join( ( self.protocol_id, - self.source_node_id, + self.version, UINT8_TO_BYTES[self.flag], + self.aes_gcm_nonce, self.auth_data_size.to_bytes(2, "big"), ) ) @classmethod def from_wire_bytes(cls, data: bytes) -> "Header": - if len(data) != 43: + if len(data) != HEADER_PACKET_SIZE: raise DecodingError( - f"Invalid length for Header: length={len(data)} data={data.hex()}" + f"Invalid length for Header: actual={len(data)} " + f"expected={HEADER_PACKET_SIZE} data={data.hex()}" ) - protocol_id = data[:8] - remote_node_id = cast(NodeID, data[8:40]) - flag = data[40] - auth_data_size = int.from_bytes(data[41:43], "big") - return cls(protocol_id, remote_node_id, flag, auth_data_size) + stream = BytesIO(data) + protocol_id = stream.read(6) + if protocol_id != PROTOCOL_ID: + raise DecodingError(f"Invalid protocol: {protocol_id!r}") + version = stream.read(2) + if version != b"\x00\x01": + raise DecodingError(f"Unsupported version: {version!r}") + flag = stream.read(1)[0] + aes_gcm_nonce = Nonce(stream.read(12)) + auth_data_size = int.from_bytes(stream.read(2), "big") + return cls(protocol_id, version, flag, aes_gcm_nonce, auth_data_size) AuthData = Union[MessagePacket, WhoAreYouPacket, HandshakePacket] @@ -182,6 +189,7 @@ def from_wire_bytes(cls, data: bytes) -> "Header": @dataclass(frozen=True) class Packet(Generic[TAuthData]): + iv: bytes header: Header auth_data: TAuthData message_cipher_text: bytes @@ -190,8 +198,9 @@ class Packet(Generic[TAuthData]): def __str__(self) -> str: return ( f"Packet[{self.auth_data.__class__.__name__}]" - f"(header={self.header}, auth_data={self.auth_data}, " - f"message_cipher_text={self.message_cipher_text!r})" + f"(iv={self.iv!r}, header={self.header}, auth_data={self.auth_data}, " + f"message_cipher_text={self.message_cipher_text!r}, " + f"dest_node_id={self.dest_node_id!r})" ) @property @@ -206,28 +215,44 @@ def is_who_are_you(self) -> bool: def is_handshake(self) -> bool: return type(self.auth_data) is HandshakePacket + @property + def challenge_data(self) -> bytes: + return b"".join( + (self.iv, self.header.to_wire_bytes(), self.auth_data.to_wire_bytes(),) + ) + @classmethod def prepare( cls, *, - nonce: Nonce, + aes_gcm_nonce: Nonce, initiator_key: AES128Key, message: BaseMessage, auth_data: TAuthData, - source_node_id: NodeID, dest_node_id: NodeID, protocol_id: bytes = PROTOCOL_ID, + iv: Optional[bytes] = None, ) -> "Packet[TAuthData]": + if iv is None: + iv = secrets.token_bytes(16) auth_data_bytes = auth_data.to_wire_bytes() auth_data_size = len(auth_data_bytes) - header = Header(protocol_id, source_node_id, auth_data.flag, auth_data_size,) + header = Header( + protocol_id, + PACKET_VERSION_1, + auth_data.flag, + aes_gcm_nonce, + auth_data_size, + ) + authenticated_data = b"".join((iv, header.to_wire_bytes(), auth_data_bytes,)) message_cipher_text = aesgcm_encrypt( key=initiator_key, - nonce=nonce, + nonce=aes_gcm_nonce, plain_text=message.to_bytes(), - authenticated_data=header.to_wire_bytes() + auth_data.to_wire_bytes(), + authenticated_data=authenticated_data, ) return cls( + iv=iv, header=header, auth_data=auth_data, message_cipher_text=message_cipher_text, @@ -237,11 +262,10 @@ def prepare( def to_wire_bytes(self) -> bytes: auth_data_bytes = self.auth_data.to_wire_bytes() header_wire_bytes = self.header.to_wire_bytes() - plain_header = header_wire_bytes + auth_data_bytes - masking_key = cast(AES128Key, self.dest_node_id[:16]) - masking_iv = secrets.token_bytes(16) - masked_header = aesctr_encrypt(masking_key, masking_iv, plain_header) - return b"".join((masking_iv, masked_header, self.message_cipher_text,)) + header_plaintext = header_wire_bytes + auth_data_bytes + masking_key = AES128Key(self.dest_node_id[:16]) + masked_header = aesctr_encrypt(masking_key, self.iv, header_plaintext) + return b"".join((self.iv, masked_header, self.message_cipher_text,)) AnyPacket = Union[ @@ -255,7 +279,7 @@ def decode_packet(data: bytes, local_node_id: NodeID,) -> AnyPacket: cipher_text_stream = aesctr_decrypt_stream(masking_key, iv, data[16:]) # Decode the header - header_bytes = bytes(take(43, cipher_text_stream)) + header_bytes = bytes(take(HEADER_PACKET_SIZE, cipher_text_stream)) header = Header.from_wire_bytes(header_bytes) auth_data_bytes = bytes(take(header.auth_data_size, cipher_text_stream)) @@ -269,9 +293,9 @@ def decode_packet(data: bytes, local_node_id: NodeID,) -> AnyPacket: else: raise DecodingError(f"Unable to decode datagram: {data.hex()}", data) - message_cipher_text = data[16 + 43 + header.auth_data_size :] + message_cipher_text = data[16 + HEADER_PACKET_SIZE + header.auth_data_size :] return cast( AnyPacket, - Packet(header, auth_data, message_cipher_text, dest_node_id=local_node_id), + Packet(iv, header, auth_data, message_cipher_text, dest_node_id=local_node_id), ) diff --git a/ddht/v5_1/session.py b/ddht/v5_1/session.py index 2dc0eb4a..ff1a0fcc 100644 --- a/ddht/v5_1/session.py +++ b/ddht/v5_1/session.py @@ -3,35 +3,31 @@ import itertools import logging import secrets -from typing import Optional, Tuple, Type, cast +from typing import Any, Optional, Tuple, Type, cast import uuid from eth_enr import ENRAPI, ENRDatabaseAPI, IdentitySchemeAPI from eth_keys import keys from eth_typing import NodeID from eth_utils import ValidationError -import rlp import trio from ddht._utils import humanize_node_id from ddht.abc import HandshakeSchemeAPI, HandshakeSchemeRegistryAPI from ddht.base_message import AnyInboundMessage, AnyOutboundMessage, BaseMessage -from ddht.encryption import aesgcm_decrypt from ddht.endpoint import Endpoint from ddht.exceptions import DecryptionError, HandshakeFailure -from ddht.handshake_schemes import default_handshake_scheme_registry from ddht.message_registry import MessageTypeRegistry from ddht.typing import AES128Key, IDNonce, Nonce, SessionKeys from ddht.v5_1.abc import EventsAPI, SessionAPI from ddht.v5_1.constants import SESSION_IDLE_TIMEOUT from ddht.v5_1.envelope import InboundEnvelope, OutboundEnvelope from ddht.v5_1.events import Events -from ddht.v5_1.messages import v51_registry +from ddht.v5_1.handshake_schemes import v51_handshake_scheme_registry +from ddht.v5_1.messages import decode_message, v51_registry from ddht.v5_1.packets import ( - AuthData, HandshakeHeader, HandshakePacket, - Header, MessagePacket, Packet, WhoAreYouPacket, @@ -53,7 +49,7 @@ class BaseSession(SessionAPI): logger = logging.getLogger("ddht.session.Session") _last_message_received_at: float - _handshake_scheme_registry: HandshakeSchemeRegistryAPI = default_handshake_scheme_registry + _handshake_scheme_registry: HandshakeSchemeRegistryAPI = v51_handshake_scheme_registry def __init__( self, @@ -150,7 +146,7 @@ def identity_scheme(self) -> Type[IdentitySchemeAPI]: return self.local_enr.identity_scheme @property - def handshake_scheme(self) -> Type[HandshakeSchemeAPI]: + def handshake_scheme(self) -> Type[HandshakeSchemeAPI[Any]]: return self._handshake_scheme_registry[self.identity_scheme] @property @@ -172,35 +168,14 @@ async def _process_message_buffers(self) -> None: ... def decode_message(self, packet: Packet[MessagePacket]) -> BaseMessage: - return self._decode_message( + return decode_message( self.keys.decryption_key, - packet.header, - packet.auth_data, - packet.auth_data.aes_gcm_nonce, + packet.header.aes_gcm_nonce, packet.message_cipher_text, + packet.challenge_data, + self._message_type_registry, ) - def _decode_message( - self, - decryption_key: AES128Key, - header: Header, - auth_data: AuthData, - nonce: Nonce, - message_cipher_text: bytes, - ) -> BaseMessage: - authenticated_data = header.to_wire_bytes() + auth_data.to_wire_bytes() - message_plain_text = aesgcm_decrypt( - key=decryption_key, - nonce=nonce, - cipher_text=message_cipher_text, - authenticated_data=authenticated_data, - ) - message_type = message_plain_text[0] - message_sedes = self._message_type_registry[message_type] - message = rlp.decode(message_plain_text[1:], sedes=message_sedes) - - return cast(BaseMessage, message) - def get_encryption_nonce(self) -> Nonce: return Nonce( next(self._nonce_counter).to_bytes(4, "big") + secrets.token_bytes(8) @@ -210,13 +185,12 @@ def prepare_envelope(self, message: AnyOutboundMessage) -> OutboundEnvelope: if not self.is_after_handshake: raise Exception("Invalid") nonce = self.get_encryption_nonce() - auth_data = MessagePacket(aes_gcm_nonce=nonce) + auth_data = MessagePacket(source_node_id=self._local_node_id) packet = Packet.prepare( - nonce=nonce, + aes_gcm_nonce=nonce, initiator_key=self.keys.encryption_key, message=message.message, auth_data=auth_data, - source_node_id=self._local_node_id, dest_node_id=self.remote_node_id, ) outbound_envelope = OutboundEnvelope(packet, self.remote_endpoint) @@ -368,11 +342,10 @@ async def handle_inbound_envelope(self, envelope: InboundEnvelope) -> bool: async def _send_handshake_initiation(self) -> None: self._initiating_packet = Packet.prepare( - nonce=cast(Nonce, secrets.token_bytes(12)), + aes_gcm_nonce=cast(Nonce, secrets.token_bytes(12)), initiator_key=cast(AES128Key, secrets.token_bytes(16)), message=RandomMessage(), - auth_data=MessagePacket(aes_gcm_nonce=cast(Nonce, secrets.token_bytes(12))), - source_node_id=self._local_node_id, + auth_data=MessagePacket(source_node_id=self._local_node_id), dest_node_id=self.remote_node_id, ) envelope = OutboundEnvelope( @@ -386,6 +359,13 @@ async def _receive_handshake_response( ) -> Tuple[SessionKeys, bytes]: self.logger.debug("%s: receiving handshake response", self) + if packet.header.aes_gcm_nonce != self._initiating_packet.header.aes_gcm_nonce: + raise HandshakeFailure( + f"WhoAreYou packet nonce does not match request nonce: " + f"expected={self._initiating_packet.header.aes_gcm_nonce} " + f"actual={packet.header.aes_gcm_nonce}" + ) + # compute session keys ephemeral_private_key = keys.PrivateKey(secrets.token_bytes(32)) @@ -395,7 +375,7 @@ async def _receive_handshake_response( remote_public_key=remote_enr.public_key, local_node_id=self._local_node_id, remote_node_id=self.remote_node_id, - id_nonce=packet.auth_data.id_nonce, + salt=packet.challenge_data, is_locally_initiated=True, ) @@ -412,14 +392,23 @@ async def _send_handshake_completion( local_enr = self.local_enr # prepare response packet - id_nonce_signature = self.handshake_scheme.create_id_nonce_signature( - id_nonce=packet.auth_data.id_nonce, + signature_inputs = self.handshake_scheme.signature_inputs_cls( + iv=packet.iv, + header=packet.header, + who_are_you=packet.auth_data, ephemeral_public_key=ephemeral_public_key, - private_key=self._local_private_key, + recipient_node_id=self.remote_node_id, + ) + id_nonce_signature = self.handshake_scheme.create_id_nonce_signature( + signature_inputs=signature_inputs, private_key=self._local_private_key, ) auth_data = HandshakePacket( - auth_data_head=HandshakeHeader.v4_header(), + auth_data_head=HandshakeHeader( + source_node_id=self._local_node_id, + signature_size=len(id_nonce_signature), + ephemeral_key_size=len(ephemeral_public_key), + ), id_signature=id_nonce_signature, ephemeral_public_key=ephemeral_public_key, record=( @@ -432,11 +421,10 @@ async def _send_handshake_completion( ), ) handshake_packet = Packet.prepare( - nonce=packet.auth_data.request_nonce, + aes_gcm_nonce=packet.header.aes_gcm_nonce, initiator_key=self.keys.encryption_key, message=self._initial_message.message, auth_data=auth_data, - source_node_id=self._local_node_id, dest_node_id=self._remote_node_id, ) @@ -556,7 +544,10 @@ async def handle_inbound_envelope(self, envelope: InboundEnvelope) -> bool: ) except HandshakeFailure: self.logger.debug( - "%s: Discarding invalid Handshake packet: %s", self, envelope + "%s: Discarding invalid Handshake packet: %s", + self, + envelope, + exc_info=True, ) await self._events.packet_discarded.trigger((self, envelope)) return False @@ -597,7 +588,7 @@ async def handle_inbound_envelope(self, envelope: InboundEnvelope) -> bool: if envelope.packet.is_message: self.logger.debug("%s: received handshake initiation", self) self._status = SessionStatus.DURING - self._remote_node_id = envelope.packet.header.source_node_id + self._remote_node_id = envelope.packet.auth_data.source_node_id # type: ignore await self._send_handshake_response( cast(Packet[MessagePacket], envelope.packet), envelope.sender_endpoint, @@ -617,25 +608,23 @@ async def _send_handshake_response( ) -> None: self.logger.debug("%s: sending handshake response", self) try: - remote_enr = self._enr_db.get_enr(packet.header.source_node_id) + remote_enr = self._enr_db.get_enr(packet.auth_data.source_node_id) except KeyError: enr_sequence_number = 0 else: enr_sequence_number = remote_enr.sequence_number auth_data = WhoAreYouPacket( - request_nonce=packet.auth_data.aes_gcm_nonce, - id_nonce=cast(IDNonce, secrets.token_bytes(32)), + id_nonce=cast(IDNonce, secrets.token_bytes(16)), enr_sequence_number=enr_sequence_number, ) self.handshake_response_packet = Packet.prepare( - nonce=cast(Nonce, secrets.token_bytes(12)), + aes_gcm_nonce=packet.header.aes_gcm_nonce, initiator_key=cast(AES128Key, secrets.token_bytes(16)), message=EmptyMessage(), auth_data=auth_data, - source_node_id=self._local_node_id, - dest_node_id=packet.header.source_node_id, + dest_node_id=packet.auth_data.source_node_id, ) envelope = OutboundEnvelope( packet=self.handshake_response_packet, receiver_endpoint=sender_endpoint, @@ -658,12 +647,18 @@ async def _receive_handshake_completion( remote_enr = self._enr_db.get_enr(self.remote_node_id) handshake_scheme = self.handshake_scheme + signature_inputs = handshake_scheme.signature_inputs_cls( + iv=self.handshake_response_packet.iv, + header=self.handshake_response_packet.header, + who_are_you=self.handshake_response_packet.auth_data, + ephemeral_public_key=packet.auth_data.ephemeral_public_key, + recipient_node_id=self._local_node_id, + ) # Verify the id_nonce_signature which ensures that the remote node has # not lied about their node_id try: handshake_scheme.validate_id_nonce_signature( - id_nonce=self.handshake_response_packet.auth_data.id_nonce, - ephemeral_public_key=packet.auth_data.ephemeral_public_key, + signature_inputs=signature_inputs, signature=packet.auth_data.id_signature, public_key=remote_enr.public_key, ) @@ -675,16 +670,16 @@ async def _receive_handshake_completion( remote_public_key=packet.auth_data.ephemeral_public_key, local_node_id=self._local_node_id, remote_node_id=self.remote_node_id, - id_nonce=self.handshake_response_packet.auth_data.id_nonce, + salt=self.handshake_response_packet.challenge_data, is_locally_initiated=False, ) - message = self._decode_message( + message = decode_message( session_keys.decryption_key, - packet.header, - packet.auth_data, - self.handshake_response_packet.auth_data.request_nonce, + packet.header.aes_gcm_nonce, packet.message_cipher_text, + packet.challenge_data, + self._message_type_registry, ) await self._inbound_message_send_channel.send( diff --git a/tests/core/test_handshake_schemes.py b/tests/core/v5/test_handshake_schemes.py similarity index 89% rename from tests/core/test_handshake_schemes.py rename to tests/core/v5/test_handshake_schemes.py index 7aadcd8c..0e9e7b6f 100644 --- a/tests/core/test_handshake_schemes.py +++ b/tests/core/v5/test_handshake_schemes.py @@ -5,13 +5,10 @@ from hypothesis import given import pytest -from ddht.constants import ID_NONCE_SIGNATURE_PREFIX -from ddht.handshake_schemes import ( - V4HandshakeScheme, - ecdh_agree, - hkdf_expand_and_extract, -) +from ddht.handshake_schemes import ecdh_agree, hkdf_expand_and_extract from ddht.tools.v5_strategies import id_nonce_st, private_key_st +from ddht.v5.constants import ID_NONCE_SIGNATURE_PREFIX +from ddht.v5.handshake_schemes import SignatureInputs, V4HandshakeScheme def test_handshake_key_generation(): @@ -39,9 +36,8 @@ def test_handshake_public_key_validation_invalid(public_key): def test_id_nonce_signing(private_key, id_nonce, ephemeral_key): ephemeral_public_key = PrivateKey(ephemeral_key).public_key.to_bytes() signature = V4HandshakeScheme.create_id_nonce_signature( - id_nonce=id_nonce, + signature_inputs=SignatureInputs(id_nonce, ephemeral_public_key), private_key=private_key, - ephemeral_public_key=ephemeral_public_key, ) signature_object = NonRecoverableSignature(signature) message_hash = sha256( @@ -56,14 +52,12 @@ def test_id_nonce_signing(private_key, id_nonce, ephemeral_key): def test_valid_id_nonce_signature_validation(private_key, id_nonce, ephemeral_key): ephemeral_public_key = PrivateKey(ephemeral_key).public_key.to_bytes() signature = V4HandshakeScheme.create_id_nonce_signature( - id_nonce=id_nonce, + signature_inputs=SignatureInputs(id_nonce, ephemeral_public_key), private_key=private_key, - ephemeral_public_key=ephemeral_public_key, ) public_key = PrivateKey(private_key).public_key.to_compressed_bytes() V4HandshakeScheme.validate_id_nonce_signature( - id_nonce=id_nonce, - ephemeral_public_key=ephemeral_public_key, + signature_inputs=SignatureInputs(id_nonce, ephemeral_public_key), signature=signature, public_key=public_key, ) @@ -74,8 +68,7 @@ def test_invalid_id_nonce_signature_validation(): private_key = b"\x11" * 32 ephemeral_public_key = b"\x22" * 64 signature = V4HandshakeScheme.create_id_nonce_signature( - id_nonce=id_nonce, - ephemeral_public_key=ephemeral_public_key, + signature_inputs=SignatureInputs(id_nonce, ephemeral_public_key), private_key=private_key, ) @@ -85,27 +78,25 @@ def test_invalid_id_nonce_signature_validation(): different_ephemeral_public_key = b"\x00" * 64 assert different_public_key != public_key assert different_id_nonce != id_nonce + assert different_ephemeral_public_key != ephemeral_public_key with pytest.raises(ValidationError): V4HandshakeScheme.validate_id_nonce_signature( - id_nonce=id_nonce, - ephemeral_public_key=ephemeral_public_key, + signature_inputs=SignatureInputs(id_nonce, ephemeral_public_key), signature=signature, public_key=different_public_key, ) with pytest.raises(ValidationError): V4HandshakeScheme.validate_id_nonce_signature( - id_nonce=different_id_nonce, - ephemeral_public_key=ephemeral_public_key, + signature_inputs=SignatureInputs(different_id_nonce, ephemeral_public_key), signature=signature, public_key=public_key, ) with pytest.raises(ValidationError): V4HandshakeScheme.validate_id_nonce_signature( - id_nonce=id_nonce, - ephemeral_public_key=different_ephemeral_public_key, + signature_inputs=SignatureInputs(id_nonce, different_ephemeral_public_key), signature=signature, public_key=public_key, ) @@ -131,7 +122,7 @@ def test_session_key_derivation(initiator_private_key, recipient_private_key, id remote_public_key=recipient_public_key, local_node_id=initiator_node_id, remote_node_id=recipient_node_id, - id_nonce=id_nonce, + salt=id_nonce, is_locally_initiated=True, ) recipient_session_keys = V4HandshakeScheme.compute_session_keys( @@ -139,7 +130,7 @@ def test_session_key_derivation(initiator_private_key, recipient_private_key, id remote_public_key=initiator_public_key, local_node_id=recipient_node_id, remote_node_id=initiator_node_id, - id_nonce=id_nonce, + salt=id_nonce, is_locally_initiated=False, ) @@ -248,8 +239,7 @@ def test_official_id_nonce_signature( id_nonce, ephemeral_public_key, local_secret_key, id_nonce_signature ): created_signature = V4HandshakeScheme.create_id_nonce_signature( - id_nonce=id_nonce, - ephemeral_public_key=ephemeral_public_key, + signature_inputs=SignatureInputs(id_nonce, ephemeral_public_key), private_key=local_secret_key, ) assert created_signature == id_nonce_signature diff --git a/tests/core/v5/test_packet_preparation.py b/tests/core/v5/test_packet_preparation.py index dae4e864..bc669321 100644 --- a/tests/core/v5/test_packet_preparation.py +++ b/tests/core/v5/test_packet_preparation.py @@ -5,7 +5,6 @@ import rlp from ddht.encryption import aesgcm_decrypt -from ddht.handshake_schemes import V4HandshakeScheme from ddht.tools.v5_strategies import ( enr_seq_st, id_nonce_st, @@ -22,6 +21,7 @@ MAGIC_SIZE, ZERO_NONCE, ) +from ddht.v5.handshake_schemes import SignatureInputs, V4HandshakeScheme from ddht.v5.messages import PingMessage, v5_registry from ddht.v5.packets import ( AuthHeader, @@ -206,9 +206,8 @@ def test_official_auth_response_encryption( secret_key, id_nonce, enr, auth_response_key, ephemeral_public_key, auth_cipher_text ): id_nonce_signature = V4HandshakeScheme.create_id_nonce_signature( - id_nonce=id_nonce, + signature_inputs=SignatureInputs(id_nonce, ephemeral_public_key), private_key=secret_key, - ephemeral_public_key=ephemeral_public_key, ) assert ( compute_encrypted_auth_response( @@ -327,8 +326,7 @@ def test_official_auth_header_packet_preparation( assert message.to_bytes() == encoded_message id_nonce_signature = V4HandshakeScheme.create_id_nonce_signature( - id_nonce=id_nonce, - ephemeral_public_key=ephemeral_public_key, + signature_inputs=SignatureInputs(id_nonce, ephemeral_public_key), private_key=local_private_key, ) diff --git a/tests/core/v5_1/test_client.py b/tests/core/v5_1/test_client.py index 56e7c39a..948a3e8a 100644 --- a/tests/core/v5_1/test_client.py +++ b/tests/core/v5_1/test_client.py @@ -39,7 +39,9 @@ async def test_client_send_pong(alice, bob, alice_client, bob_client): with trio.fail_after(2): async with alice.events.pong_sent.subscribe_and_wait(): async with bob.events.pong_received.subscribe_and_wait(): - await alice_client.send_pong(bob.endpoint, bob.node_id, request_id=1234) + await alice_client.send_pong( + bob.endpoint, bob.node_id, request_id=b"\x12" + ) @pytest.mark.trio @@ -67,7 +69,7 @@ async def test_client_send_found_nodes(alice, bob, alice_client, bob_client, enr async with alice.events.found_nodes_sent.subscribe_and_wait(): async with bob.events.found_nodes_received.subscribe() as subscription: await alice_client.send_found_nodes( - bob.endpoint, bob.node_id, enrs=enrs, request_id=1234, + bob.endpoint, bob.node_id, enrs=enrs, request_id=b"\x12", ) with trio.fail_after(2): @@ -107,7 +109,7 @@ async def test_client_send_talk_response(alice, bob, alice_client, bob_client): bob.endpoint, bob.node_id, payload=b"test-response", - request_id=1234, + request_id=b"\x12", ) @@ -122,7 +124,7 @@ async def test_client_send_register_topic(alice, bob, alice_client, bob_client): topic=b"unicornsrainbowsunicornsrainbows", enr=alice.enr, ticket=b"test-ticket", - request_id=1234, + request_id=b"\x12", ) @@ -136,7 +138,7 @@ async def test_client_send_ticket(alice, bob, alice_client, bob_client): bob.node_id, ticket=b"test-ticket", wait_time=600, - request_id=1234, + request_id=b"\x12", ) @@ -151,7 +153,7 @@ async def test_client_send_registration_confirmation( bob.endpoint, bob.node_id, topic=b"unicornsrainbowsunicornsrainbows", - request_id=1234, + request_id=b"\x12", ) @@ -164,7 +166,7 @@ async def test_client_send_topic_query(alice, bob, alice_client, bob_client): bob.endpoint, bob.node_id, topic=b"unicornsrainbowsunicornsrainbows", - request_id=1234, + request_id=b"\x12", ) @@ -173,8 +175,8 @@ async def test_client_send_topic_query(alice, bob, alice_client, bob_client): # @pytest.mark.trio async def test_client_request_response_ping_pong(alice, bob, alice_client, bob_client): - async with trio.open_nursery() as nursery: - async with bob.events.ping_received.subscribe() as subscription: + async with bob.events.ping_received.subscribe() as subscription: + async with trio.open_nursery() as nursery: async def _send_response(): ping = await subscription.receive() diff --git a/tests/core/v5_1/test_dispatcher.py b/tests/core/v5_1/test_dispatcher.py index 84c97f77..62c3318e 100644 --- a/tests/core/v5_1/test_dispatcher.py +++ b/tests/core/v5_1/test_dispatcher.py @@ -15,7 +15,7 @@ async def test_dispatcher_handles_incoming_envelopes(tester, driver, alice, bob) ): with trio.fail_after(1): async with alice.events.ping_received.subscribe_and_wait(): - await driver.recipient.send_ping(1234) + await driver.recipient.send_ping(b"\x12") @pytest.fixture @@ -38,7 +38,7 @@ async def test_dispatcher_bidirectional_communication( async with bob.events.ping_received.subscribe_and_wait(): await alice_dispatcher.send_message( OutboundMessage( - PingMessage(1234, alice.enr.sequence_number), + PingMessage(b"\x12", alice.enr.sequence_number), bob.endpoint, bob.node_id, ) @@ -59,7 +59,7 @@ async def test_dispatcher_handles_incoming_envelopes_with_multiple_sessions( with trio.fail_after(2): async with driver_a.initiator.events.packet_discarded.subscribe_and_wait(): async with alice.events.ping_received.subscribe_and_wait(): - await driver_b.initiator.send_ping(1234) + await driver_b.initiator.send_ping(b"\x12") @pytest.mark.trio @@ -71,14 +71,18 @@ async def test_dispatcher_send_message_with_existing_session( async with bob.events.ping_received.subscribe_and_wait(): await alice_dispatcher.send_message( OutboundMessage( - PingMessage(1234, alice.enr.sequence_number), bob.endpoint, bob.node_id, + PingMessage(b"\x12", alice.enr.sequence_number), + bob.endpoint, + bob.node_id, ) ) async with bob.events.ping_received.subscribe_and_wait(): await alice_dispatcher.send_message( OutboundMessage( - PingMessage(4321, alice.enr.sequence_number), bob.endpoint, bob.node_id, + PingMessage(b"\x34", alice.enr.sequence_number), + bob.endpoint, + bob.node_id, ) ) @@ -93,7 +97,7 @@ async def test_dispatcher_send_message_creates_session( async with alice.events.session_created.subscribe_and_wait(): await alice_dispatcher.send_message( OutboundMessage( - PingMessage(1234, alice.enr.sequence_number), + PingMessage(b"\x12", alice.enr.sequence_number), carol.endpoint, carol.node_id, ) @@ -112,10 +116,10 @@ async def test_dispatcher_subscribe_to_message_type(tester, alice, bob): async with tester.dispatcher_pair(alice, bob) as (alice_dispatcher, _): async with tester.dispatcher_pair(alice, carol): async with alice_dispatcher.subscribe(PingMessage) as ping_subscription: - await driver_a.recipient.send_ping(1234) - await driver_a.initiator.send_pong(1234) - await driver_b.initiator.send_ping(4321) - await driver_b.recipient.send_pong(4321) + await driver_a.recipient.send_ping(b"\x12") + await driver_a.initiator.send_pong(b"\x12") + await driver_b.initiator.send_ping(b"\x34") + await driver_b.recipient.send_pong(b"\x34") with trio.fail_after(1): ping_message_a = await ping_subscription.receive() @@ -141,15 +145,15 @@ async def test_dispatcher_subscribe_to_message_type_with_endpoint_filter( async with alice_dispatcher.subscribe( PingMessage, endpoint=carol.endpoint ) as ping_subscription: - await driver_a.recipient.send_ping(1234) - await driver_a.initiator.send_pong(1234) - await driver_b.initiator.send_ping(4321) - await driver_b.recipient.send_pong(4321) + await driver_a.recipient.send_ping(b"\x12") + await driver_a.initiator.send_pong(b"\x12") + await driver_b.initiator.send_ping(b"\x34") + await driver_b.recipient.send_pong(b"\x34") with trio.fail_after(1): ping_message_a = await ping_subscription.receive() - assert ping_message_a.message.request_id == 4321 + assert ping_message_a.message.request_id == b"\x34" assert ping_message_a.sender_node_id == carol.node_id @@ -169,15 +173,15 @@ async def test_dispatcher_subscribe_to_message_type_with_node_id_filter( async with alice_dispatcher.subscribe( PingMessage, node_id=carol.node_id ) as ping_subscription: - await driver_a.recipient.send_ping(1234) - await driver_a.initiator.send_pong(1234) - await driver_b.initiator.send_ping(4321) - await driver_b.recipient.send_pong(4321) + await driver_a.recipient.send_ping(b"\x12") + await driver_a.initiator.send_pong(b"\x12") + await driver_b.initiator.send_ping(b"\x34") + await driver_b.recipient.send_pong(b"\x34") with trio.fail_after(1): ping_message_a = await ping_subscription.receive() - assert ping_message_a.message.request_id == 4321 + assert ping_message_a.message.request_id == b"\x34" assert ping_message_a.sender_node_id == carol.node_id @@ -193,18 +197,20 @@ async def test_dispatcher_subscribe_request_response(tester, alice, bob): async with tester.dispatcher_pair(alice, bob) as (alice_dispatcher, _): async with tester.dispatcher_pair(alice, carol): request = OutboundMessage( - PingMessage(1234, alice.enr.sequence_number), bob.endpoint, bob.node_id, + PingMessage(b"\x12", alice.enr.sequence_number), + bob.endpoint, + bob.node_id, ) async with alice_dispatcher.subscribe_request( request, PongMessage ) as subscription: - await driver_a.initiator.send_ping(1234) - await driver_b.initiator.send_ping(4321) - await driver_b.recipient.send_pong(4321) - await driver_a.recipient.send_pong(1234) + await driver_a.initiator.send_ping(b"\x12") + await driver_b.initiator.send_ping(b"\x34") + await driver_b.recipient.send_pong(b"\x34") + await driver_a.recipient.send_pong(b"\x12") with trio.fail_after(1): response = await subscription.receive() assert response.sender_node_id == bob.node_id - assert response.message.request_id == 1234 + assert response.message.request_id == b"\x12" diff --git a/tests/core/v5_1/test_handshake_schemes.py b/tests/core/v5_1/test_handshake_schemes.py new file mode 100644 index 00000000..40624d42 --- /dev/null +++ b/tests/core/v5_1/test_handshake_schemes.py @@ -0,0 +1,394 @@ +from hashlib import sha256 + +from eth_keys.datatypes import NonRecoverableSignature, PrivateKey +from eth_utils import ValidationError, decode_hex, keccak +from hypothesis import given +from hypothesis import strategies as st +import pytest + +from ddht.handshake_schemes import ecdh_agree, hkdf_expand_and_extract +from ddht.tools.factories.v5_1 import HeaderFactory, WhoAreYouPacketFactory +from ddht.tools.v5_strategies import iv_st, node_id_st, private_key_st +from ddht.v5_1.constants import ID_NONCE_SIGNATURE_PREFIX, WHO_ARE_YOU_PACKET_SIZE +from ddht.v5_1.handshake_schemes import SignatureInputs, V4HandshakeScheme +from ddht.v5_1.packets import WhoAreYouPacket + +header_st = st.binary(min_size=12, max_size=12).map( + lambda aes_gcm_nonce: HeaderFactory( + flag=WhoAreYouPacket.flag, + aes_gcm_nonce=aes_gcm_nonce, + auth_data_size=WHO_ARE_YOU_PACKET_SIZE, + ) +) +who_are_you_st = st.tuples( + st.binary(min_size=16, max_size=16), st.integers(min_value=0, max_value=65536), +).map(lambda id_nonce_and_seq_num: WhoAreYouPacket(*id_nonce_and_seq_num)) + + +def test_handshake_key_generation(): + private_key, public_key = V4HandshakeScheme.create_handshake_key_pair() + V4HandshakeScheme.validate_uncompressed_public_key(public_key) + V4HandshakeScheme.validate_handshake_public_key(public_key) + assert PrivateKey(private_key).public_key.to_bytes() == public_key + + +@pytest.mark.parametrize("public_key", (b"\x01" * 64, b"\x02" * 64)) +def test_handshake_public_key_validation_valid(public_key): + V4HandshakeScheme.validate_handshake_public_key(public_key) + + +@pytest.mark.parametrize( + "public_key", + (b"", b"\x02" * 31, b"\x02" * 32, b"\x02" * 33, b"\x02" * 63, b"\x02" * 65), +) +def test_handshake_public_key_validation_invalid(public_key): + with pytest.raises(ValidationError): + V4HandshakeScheme.validate_handshake_public_key(public_key) + + +@given( + private_key=private_key_st, + iv=iv_st, + header=header_st, + who_are_you=who_are_you_st, + ephemeral_key=private_key_st, + recipient_node_id=node_id_st, +) +def test_id_nonce_signing( + private_key, iv, header, who_are_you, ephemeral_key, recipient_node_id +): + ephemeral_public_key = PrivateKey(ephemeral_key).public_key.to_bytes() + signature = V4HandshakeScheme.create_id_nonce_signature( + signature_inputs=SignatureInputs( + iv, header, who_are_you, ephemeral_public_key, recipient_node_id + ), + private_key=private_key, + ) + signature_object = NonRecoverableSignature(signature) + message_hash = sha256( + ID_NONCE_SIGNATURE_PREFIX + + iv + + header.to_wire_bytes() + + who_are_you.to_wire_bytes() + + ephemeral_public_key + + recipient_node_id + ).digest() + assert signature_object.verify_msg_hash( + message_hash, PrivateKey(private_key).public_key + ) + + +@given( + private_key=private_key_st, + iv=iv_st, + header=header_st, + who_are_you=who_are_you_st, + ephemeral_key=private_key_st, + recipient_node_id=node_id_st, +) +def test_valid_id_nonce_signature_validation( + private_key, iv, header, who_are_you, ephemeral_key, recipient_node_id, +): + ephemeral_public_key = PrivateKey(ephemeral_key).public_key.to_bytes() + signature = V4HandshakeScheme.create_id_nonce_signature( + signature_inputs=SignatureInputs( + iv=iv, + header=header, + who_are_you=who_are_you, + ephemeral_public_key=ephemeral_public_key, + recipient_node_id=recipient_node_id, + ), + private_key=private_key, + ) + public_key = PrivateKey(private_key).public_key.to_compressed_bytes() + V4HandshakeScheme.validate_id_nonce_signature( + signature_inputs=SignatureInputs( + iv, header, who_are_you, ephemeral_public_key, recipient_node_id + ), + signature=signature, + public_key=public_key, + ) + + +def test_invalid_id_nonce_signature_validation(): + iv = b"\xaa" * 16 + who_are_you = WhoAreYouPacketFactory() + header = HeaderFactory() + private_key = b"\xcc" * 32 + ephemeral_public_key = b"\xdd" * 64 + recipient_node_id = b"\xee" * 32 + + signature = V4HandshakeScheme.create_id_nonce_signature( + signature_inputs=SignatureInputs( + iv=iv, + header=header, + who_are_you=who_are_you, + ephemeral_public_key=ephemeral_public_key, + recipient_node_id=recipient_node_id, + ), + private_key=private_key, + ) + public_key = PrivateKey(private_key).public_key.to_compressed_bytes() + + different_public_key = PrivateKey(b"\x22" * 32).public_key.to_compressed_bytes() + different_iv = b"\x11" * 16 + different_header = HeaderFactory( + flag=WhoAreYouPacket.flag, auth_data_size=WHO_ARE_YOU_PACKET_SIZE, + ) + different_who_are_you = WhoAreYouPacketFactory() + different_ephemeral_public_key = b"\x33" * 64 + different_recipient_node_id = b"\x44" * 32 + + assert different_public_key != public_key + assert different_iv != iv + assert different_header != header + assert different_who_are_you != who_are_you + assert different_recipient_node_id != recipient_node_id + + # wrong public_key + with pytest.raises(ValidationError): + V4HandshakeScheme.validate_id_nonce_signature( + signature_inputs=SignatureInputs( + iv, header, who_are_you, ephemeral_public_key, recipient_node_id + ), + signature=signature, + public_key=different_public_key, + ) + + # wrong iv + with pytest.raises(ValidationError): + V4HandshakeScheme.validate_id_nonce_signature( + signature_inputs=SignatureInputs( + different_iv, + header, + who_are_you, + ephemeral_public_key, + recipient_node_id, + ), + signature=signature, + public_key=public_key, + ) + + # wrong header + with pytest.raises(ValidationError): + V4HandshakeScheme.validate_id_nonce_signature( + signature_inputs=SignatureInputs( + iv, + different_header, + who_are_you, + ephemeral_public_key, + recipient_node_id, + ), + signature=signature, + public_key=public_key, + ) + + # wrong who-are-you + with pytest.raises(ValidationError): + V4HandshakeScheme.validate_id_nonce_signature( + signature_inputs=SignatureInputs( + iv, + header, + different_who_are_you, + ephemeral_public_key, + recipient_node_id, + ), + signature=signature, + public_key=public_key, + ) + + # wrong ephemeral_public_key + with pytest.raises(ValidationError): + V4HandshakeScheme.validate_id_nonce_signature( + signature_inputs=SignatureInputs( + iv, + header, + who_are_you, + different_ephemeral_public_key, + recipient_node_id, + ), + signature=signature, + public_key=public_key, + ) + + # wrong recipient_node_id + with pytest.raises(ValidationError): + V4HandshakeScheme.validate_id_nonce_signature( + signature_inputs=SignatureInputs( + iv, + header, + who_are_you, + ephemeral_public_key, + different_recipient_node_id, + ), + signature=signature, + public_key=public_key, + ) + + +@given( + initiator_private_key=private_key_st, + recipient_private_key=private_key_st, + iv=iv_st, + header=header_st, + who_are_you=who_are_you_st, +) +def test_session_key_derivation( + initiator_private_key, recipient_private_key, iv, header, who_are_you +): + initiator_private_key_object = PrivateKey(initiator_private_key) + recipient_private_key_object = PrivateKey(recipient_private_key) + + initiator_public_key = initiator_private_key_object.public_key.to_bytes() + recipient_public_key = recipient_private_key_object.public_key.to_bytes() + + initiator_node_id = keccak(initiator_private_key_object.public_key.to_bytes()) + recipient_node_id = keccak(recipient_private_key_object.public_key.to_bytes()) + + challenge_data = iv + header.to_wire_bytes() + who_are_you.to_wire_bytes() + + initiator_session_keys = V4HandshakeScheme.compute_session_keys( + local_private_key=initiator_private_key, + remote_public_key=recipient_public_key, + local_node_id=initiator_node_id, + remote_node_id=recipient_node_id, + salt=challenge_data, + is_locally_initiated=True, + ) + recipient_session_keys = V4HandshakeScheme.compute_session_keys( + local_private_key=recipient_private_key, + remote_public_key=initiator_public_key, + local_node_id=recipient_node_id, + remote_node_id=initiator_node_id, + salt=challenge_data, + is_locally_initiated=False, + ) + + assert ( + initiator_session_keys.auth_response_key + == recipient_session_keys.auth_response_key + ) + assert ( + initiator_session_keys.encryption_key == recipient_session_keys.decryption_key + ) + assert ( + initiator_session_keys.decryption_key == recipient_session_keys.encryption_key + ) + + +@pytest.mark.parametrize( + ["local_secret_key", "remote_public_key", "shared_secret_key"], + [ + [ + decode_hex( + "0xfb757dc581730490a1d7a00deea65e9b1936924caaea8f44d476014856b68736" + ), + decode_hex( + "0x9961e4c2356d61bedb83052c115d311acb3a96f5777296dcf297351130266231503061ac4aaee666073d" # noqa: E501 + "7e5bc2c80c3f5c5b500c1cb5fd0a76abbb6b675ad157" + ), + decode_hex( + "0x033b11a2a1f214567e1537ce5e509ffd9b21373247f2a3ff6841f4976f53165e7e" + ), + ] + ], +) +def test_official_key_agreement(local_secret_key, remote_public_key, shared_secret_key): + assert ecdh_agree(local_secret_key, remote_public_key) == shared_secret_key + + +@pytest.mark.parametrize( + [ + "secret", + "initiator_node_id", + "recipient_node_id", + "id_nonce", + "initiator_key", + "recipient_key", + "auth_response_key", + ], + [ + [ + decode_hex( + "0x02a77e3aa0c144ae7c0a3af73692b7d6e5b7a2fdc0eda16e8d5e6cb0d08e88dd04" + ), + decode_hex( + "0xa448f24c6d18e575453db13171562b71999873db5b286df957af199ec94617f7" + ), + decode_hex( + "0x885bba8dfeddd49855459df852ad5b63d13a3fae593f3f9fa7e317fd43651409" + ), + decode_hex( + "0x0101010101010101010101010101010101010101010101010101010101010101" + ), + decode_hex("0x238d8b50e4363cf603a48c6cc3542967"), + decode_hex("0xbebc0183484f7e7ca2ac32e3d72c8891"), + decode_hex("0xe987ad9e414d5b4f9bfe4ff1e52f2fae"), + ] + ], +) +def test_official_key_derivation( + secret, + initiator_node_id, + recipient_node_id, + id_nonce, + initiator_key, + recipient_key, + auth_response_key, +): + derived_keys = hkdf_expand_and_extract( + secret, initiator_node_id, recipient_node_id, id_nonce + ) + assert derived_keys[0] == initiator_key + assert derived_keys[1] == recipient_key + assert derived_keys[2] == auth_response_key + + +@pytest.mark.skip(reason="NO UPDATED TEST VECTORS") +@pytest.mark.parametrize( + [ + "iv", + "id_nonce", + "ephemeral_public_key", + "local_secret_key", + "recipient_node_id", + "id_nonce_signature", + ], + [ + [ + decode_hex("0x0011223344556677889900aabbccddeeff"), + decode_hex( + "0xa77e3aa0c144ae7c0a3af73692b7d6e5b7a2fdc0eda16e8d5e6cb0d08e88dd04" + ), + decode_hex( + "0x9961e4c2356d61bedb83052c115d311acb3a96f5777296dcf297351130266231503061ac4aaee666" + "073d7e5bc2c80c3f5c5b500c1cb5fd0a76abbb6b675ad157" + ), + decode_hex( + "0xfb757dc581730490a1d7a00deea65e9b1936924caaea8f44d476014856b68736" + ), + decode_hex( + "0xbbbb9d047f0488c0b5a93c1c3f2d8bafc7c8ff337024a55434a0d0555de64db9", + ), + decode_hex( + "0x7fa74e136e54473053134d7d66f29d927756dd86b36f745f256f0f87e08503cb70543290674dd9e9" + "5297a7e2f2983d83d881396bad3688612773fde0e586668d" + ), + ] + ], +) +def test_official_id_nonce_signature( + iv, + id_nonce, + ephemeral_public_key, + local_secret_key, + recipient_node_id, + id_nonce_signature, +): + created_signature = V4HandshakeScheme.create_id_nonce_signature( + signature_inputs=SignatureInputs( + iv, id_nonce, ephemeral_public_key, recipient_node_id + ), + private_key=local_secret_key, + ) + assert created_signature == id_nonce_signature diff --git a/tests/core/v5_1/test_packet_encoding.py b/tests/core/v5_1/test_packet_encoding.py index 6e396a0c..8a472f12 100644 --- a/tests/core/v5_1/test_packet_encoding.py +++ b/tests/core/v5_1/test_packet_encoding.py @@ -14,18 +14,17 @@ def test_message_packet_encoding(): initiator_key = b"\x01" * 16 - nonce = b"\x02" * 12 + aes_gcm_nonce = b"\x02" * 12 source_node_id = b"\x03" * 32 dest_node_id = b"\x04" * 32 - message = PingMessage(1, 0) - auth_data = MessagePacket(b"\x05" * 12) + message = PingMessage(b"\x01", 0) + auth_data = MessagePacket(source_node_id) packet = Packet.prepare( - nonce=nonce, + aes_gcm_nonce=aes_gcm_nonce, initiator_key=initiator_key, message=message, auth_data=auth_data, - source_node_id=source_node_id, dest_node_id=dest_node_id, ) packet_wire_bytes = packet.to_wire_bytes() @@ -36,20 +35,16 @@ def test_message_packet_encoding(): def test_who_are_you_packet_encoding(): initiator_key = b"\x01" * 16 - nonce = b"\x02" * 12 - source_node_id = b"\x03" * 32 + aes_gcm_nonce = b"\x02" * 12 dest_node_id = b"\x04" * 32 - message = PingMessage(1, 0) - auth_data = WhoAreYouPacket( - request_nonce=b"\x05" * 12, id_nonce=b"\x06" * 32, enr_sequence_number=0x07 - ) + message = PingMessage(b"\x01", 0) + auth_data = WhoAreYouPacket(id_nonce=b"\x06" * 16, enr_sequence_number=0x07) packet = Packet.prepare( - nonce=nonce, + aes_gcm_nonce=aes_gcm_nonce, initiator_key=initiator_key, message=message, auth_data=auth_data, - source_node_id=source_node_id, dest_node_id=dest_node_id, ) packet_wire_bytes = packet.to_wire_bytes() @@ -63,13 +58,13 @@ def test_who_are_you_packet_encoding(): ) def test_handshake_packet_encoding(enr): initiator_key = b"\x01" * 16 - nonce = b"\x02" * 12 + aes_gcm_nonce = b"\x02" * 12 source_node_id = b"\x03" * 32 dest_node_id = b"\x04" * 32 - message = PingMessage(1, 0) + message = PingMessage(b"\x01", 0) auth_data = HandshakePacket( auth_data_head=HandshakeHeader( - version=1, signature_size=64, ephemeral_key_size=33, + source_node_id=source_node_id, signature_size=64, ephemeral_key_size=33, ), id_signature=b"\x05" * 64, ephemeral_public_key=b"\x06" * 33, @@ -77,11 +72,10 @@ def test_handshake_packet_encoding(enr): ) packet = Packet.prepare( - nonce=nonce, + aes_gcm_nonce=aes_gcm_nonce, initiator_key=initiator_key, message=message, auth_data=auth_data, - source_node_id=source_node_id, dest_node_id=dest_node_id, ) packet_wire_bytes = packet.to_wire_bytes() diff --git a/tests/core/v5_1/test_session.py b/tests/core/v5_1/test_session.py index 1c141f4b..b12a8930 100644 --- a/tests/core/v5_1/test_session.py +++ b/tests/core/v5_1/test_session.py @@ -3,7 +3,7 @@ from eth_utils import int_to_big_endian import pytest from rlp.exceptions import DecodingError, DeserializationError -from rlp.sedes import big_endian_int +from rlp.sedes import binary import trio from ddht.base_message import BaseMessage @@ -65,11 +65,11 @@ async def test_session_message_sending_during_handshake(driver): assert driver.recipient.session.is_before_handshake # initiate the handshake - await driver.initiator.send_ping(0) + await driver.initiator.send_ping(b"\x00") # send first message before initiation packet is transmitted # we cannot send a message from the recipient until they have the remote node id - await driver.initiator.send_ping(1) + await driver.initiator.send_ping(b"\x01") assert driver.initiator.session.is_during_handshake assert driver.recipient.session.is_before_handshake @@ -81,8 +81,8 @@ async def test_session_message_sending_during_handshake(driver): assert driver.recipient.session.is_during_handshake # send second message after initiation packet is transmitted - await driver.initiator.send_ping(2) - await driver.recipient.send_ping(3) + await driver.initiator.send_ping(b"\x02") + await driver.recipient.send_ping(b"\x03") # step the handshake forward await driver.transmit_one(driver.recipient) @@ -91,8 +91,8 @@ async def test_session_message_sending_during_handshake(driver): assert driver.recipient.session.is_during_handshake # send third message after initiation packet is transmitted - await driver.initiator.send_ping(4) - await driver.recipient.send_ping(5) + await driver.initiator.send_ping(b"\x04") + await driver.recipient.send_ping(b"\x05") # step the handshake forward await driver.transmit_one(driver.initiator) @@ -110,13 +110,13 @@ async def test_session_message_sending_during_handshake(driver): ping_3 = await driver.initiator.next_message() ping_5 = await driver.initiator.next_message() - assert ping_0.message.request_id == 0 - assert ping_1.message.request_id == 1 - assert ping_2.message.request_id == 2 - assert ping_4.message.request_id == 4 + assert ping_0.message.request_id == b"\x00" + assert ping_1.message.request_id == b"\x01" + assert ping_2.message.request_id == b"\x02" + assert ping_4.message.request_id == b"\x04" - assert ping_3.message.request_id == 3 - assert ping_5.message.request_id == 5 + assert ping_3.message.request_id == b"\x03" + assert ping_5.message.request_id == b"\x05" @pytest.mark.trio @@ -124,13 +124,13 @@ async def test_session_message_sending_after_handshake(driver): await driver.handshake() async with driver.transmit(): - await driver.initiator.send_ping(1234) + await driver.initiator.send_ping(b"\x12") ping_message = await driver.recipient.next_message() - assert ping_message.message.request_id == 1234 + assert ping_message.message.request_id == b"\x12" - await driver.recipient.send_pong(1234) + await driver.recipient.send_pong(b"\x12") pong_message = await driver.initiator.next_message() - assert pong_message.message.request_id == 1234 + assert pong_message.message.request_id == b"\x12" @pytest.mark.trio @@ -142,7 +142,7 @@ async def test_session_unexpected_packets(driver): assert driver.recipient.session.is_before_handshake # initiate the handshake - await driver.initiator.send_ping(1234) + await driver.initiator.send_ping(b"\x12") assert driver.initiator.session.is_during_handshake assert driver.recipient.session.is_before_handshake @@ -153,9 +153,7 @@ async def test_session_unexpected_packets(driver): # since the recipient has not yet sent the `WhoAreYouPacket` async with driver.recipient.events.packet_discarded.subscribe_and_wait(): await driver.send_packet( - PacketFactory.who_are_you( - source_node_id=initiator.node_id, dest_node_id=recipient.node_id, - ) + PacketFactory.who_are_you(dest_node_id=recipient.node_id,) ) async with driver.recipient.events.packet_discarded.subscribe_and_wait(): await driver.send_packet( @@ -191,9 +189,7 @@ async def test_session_unexpected_packets(driver): # reason for the initiator to send such a packet. async with driver.recipient.events.packet_discarded.subscribe_and_wait(): await driver.send_packet( - PacketFactory.who_are_you( - source_node_id=initiator.node_id, dest_node_id=recipient.node_id, - ) + PacketFactory.who_are_you(dest_node_id=recipient.node_id,) ) # The initiator should discard a HandshakePacket since there is no @@ -216,9 +212,7 @@ async def test_session_unexpected_packets(driver): # reason for the initiator to send such a packet. async with driver.recipient.events.packet_discarded.subscribe_and_wait(): await driver.send_packet( - PacketFactory.who_are_you( - source_node_id=initiator.node_id, dest_node_id=recipient.node_id, - ) + PacketFactory.who_are_you(dest_node_id=recipient.node_id,) ) # The recipient should buffer any message packets it receives at this @@ -226,9 +220,9 @@ async def test_session_unexpected_packets(driver): # the initiator can now have valid session keys. await driver.send_packet( PacketFactory.message( - nonce=driver.initiator.session.get_encryption_nonce(), + aes_gcm_nonce=driver.initiator.session.get_encryption_nonce(), initiator_key=driver.initiator.session.keys.encryption_key, - message=PingMessage(4321, initiator.enr.sequence_number), + message=PingMessage(b"\x34", initiator.enr.sequence_number), source_node_id=initiator.node_id, dest_node_id=recipient.node_id, ) @@ -254,16 +248,16 @@ async def test_session_unexpected_packets(driver): initiation_ping = await driver.recipient.next_message() out_of_order_ping = await driver.recipient.next_message() - assert initiation_ping.message.request_id == 1234 - assert out_of_order_ping.message.request_id == 4321 + assert initiation_ping.message.request_id == b"\x12" + assert out_of_order_ping.message.request_id == b"\x34" class BadMessage(BaseMessage): - fields = (("request_id", big_endian_int),) + fields = (("request_id", binary),) def __init__(self, message_type, request_id=None): if request_id is None: - request_id = secrets.randbits(32) + request_id = int_to_big_endian(secrets.randbits(32)) self.message_type = message_type super().__init__(request_id=request_id) @@ -276,7 +270,7 @@ async def test_session_message_mismatched_rlp(driver): with pytest.raises(DeserializationError): await driver.send_packet( PacketFactory.message( - nonce=driver.initiator.session.get_encryption_nonce(), + aes_gcm_nonce=driver.initiator.session.get_encryption_nonce(), initiator_key=driver.initiator.session.keys.encryption_key, message=BadMessage(1), source_node_id=driver.initiator.node.node_id, @@ -293,7 +287,7 @@ async def test_session_message_unknown_message_type(driver): with pytest.raises(KeyError): await driver.send_packet( PacketFactory.message( - nonce=driver.initiator.session.get_encryption_nonce(), + aes_gcm_nonce=driver.initiator.session.get_encryption_nonce(), initiator_key=driver.initiator.session.keys.encryption_key, message=BadMessage(255), source_node_id=driver.initiator.node.node_id, @@ -303,7 +297,7 @@ async def test_session_message_unknown_message_type(driver): class GarbledMessage(BaseMessage): - fields = (("request_id", big_endian_int),) + fields = (("request_id", binary),) def __init__(self, message_type, message_bytes): self.message_type = message_type @@ -321,7 +315,7 @@ async def test_session_invalid_rlp(driver): with pytest.raises(DecodingError): await driver.send_packet( PacketFactory.message( - nonce=driver.initiator.session.get_encryption_nonce(), + aes_gcm_nonce=driver.initiator.session.get_encryption_nonce(), initiator_key=driver.initiator.session.keys.encryption_key, message=GarbledMessage(1, b"\xff\xff\xff"), source_node_id=driver.initiator.node.node_id, @@ -379,9 +373,7 @@ async def test_session_last_message_received_at(driver, autojump_clock): recipient_node_id = driver.recipient.node.node_id await driver.send_packet( - PacketFactory.who_are_you( - source_node_id=initiator_node_id, dest_node_id=recipient_node_id, - ) + PacketFactory.who_are_you(dest_node_id=recipient_node_id,) ) await trio.sleep(0.01) # let the packet process @@ -389,9 +381,7 @@ async def test_session_last_message_received_at(driver, autojump_clock): assert recipient.last_message_received_at < anchor await driver.send_packet( - PacketFactory.who_are_you( - source_node_id=recipient_node_id, dest_node_id=initiator_node_id, - ) + PacketFactory.who_are_you(dest_node_id=initiator_node_id,) ) await trio.sleep(0.01) # let the packet process diff --git a/tests/core/v5_1/test_specification_fixtures.py b/tests/core/v5_1/test_specification_fixtures.py new file mode 100644 index 00000000..b0d96720 --- /dev/null +++ b/tests/core/v5_1/test_specification_fixtures.py @@ -0,0 +1,370 @@ +import io + +from eth_keys import keys +from eth_utils import decode_hex, to_int +import pytest + +from ddht.v5_1.constants import HEADER_PACKET_SIZE, WHO_ARE_YOU_PACKET_SIZE +from ddht.v5_1.handshake_schemes import V4HandshakeScheme +from ddht.v5_1.messages import PingMessage, decode_message +from ddht.v5_1.packets import ( + HandshakeHeader, + HandshakePacket, + Header, + MessagePacket, + WhoAreYouPacket, + decode_packet, +) + +NODE_KEY_A = decode_hex( + "0xeef77acb6c6a6eebc5b363a475ac583ec7eccdb42b6481424c60f59aa326547f" +) +NODE_KEY_B = decode_hex( + "0x66fb62bfbd66b9177a138c1e5cddbe4f7c30c343e94e68df8769459cb1cde628" +) + + +PACKET_DECODING_FIXTURES = ( + # ping message packet + { + "src-node-id": "0xaaaa8419e9f49d0083561b48287df592939a8d19947d8c0ef88f2a4856a69fbb", + "dest-node-id": "0xbbbb9d047f0488c0b5a93c1c3f2d8bafc7c8ff337024a55434a0d0555de64db9", + "read-key": "0x00000000000000000000000000000000", + "type": "decoding", + "nonce": "0xffffffffffffffffffffffff", + "packet": { + "type": "message", + "message": {"req-id": "0x00000001", "enr-seq": "0x2"}, + }, + "encoded": ( + "00000000000000000000000000000000088b3d4342774649325f313964a39e55" + "ea96c005ad52be8c7560413a7008f16c9e6d2f43bbea8814a546b7409ce783d3" + "4c4f53245d08dab84102ed931f66d1492acb308fa1c6715b9d139b81acbdcc" + ), + }, + # who are you packet + { + "src-node-id": "0xaaaa8419e9f49d0083561b48287df592939a8d19947d8c0ef88f2a4856a69fbb", + "dest-node-id": "0xbbbb9d047f0488c0b5a93c1c3f2d8bafc7c8ff337024a55434a0d0555de64db9", + "type": "decoding", + "packet": { + "type": "whoareyou", + "iv": "0x00000000000000000000000000000000", + "authdata": "0x0102030405060708090a0b0c0d0e0f100000000000000000", + "request-nonce": "0x0102030405060708090a0b0c", + "id-nonce": "0x0102030405060708090a0b0c0d0e0f10", + "enr-seq": "0x0", + }, + "encoded": ( + "00000000000000000000000000000000088b3d434277464933a1ccc59f5967ad" + "1d6035f15e528627dde75cd68292f9e6c27d6b66c8100a873fcbaed4e16b8d" + ), + }, + # handshake packet (ping message) (without ENR) + { + "read-key": "0x4f9fac6de7567d1e3b1241dffe90f662", + "src-node-id": "0xaaaa8419e9f49d0083561b48287df592939a8d19947d8c0ef88f2a4856a69fbb", + "dest-node-id": "0xbbbb9d047f0488c0b5a93c1c3f2d8bafc7c8ff337024a55434a0d0555de64db9", + "nonce": "0xffffffffffffffffffffffff", + "type": "decoding", + "packet": { + "type": "handshake", + "message": {"req-id": "0x00000001", "enr-seq": "0x1"}, + }, + "handshake-inputs": { + "whoareyou": { + "challenge-data": ( + "0x" + "0000000000000000000000000000000064697363763500010101020304050607" + "08090a0b0c00180102030405060708090a0b0c0d0e0f100000000000000001" + ), + "authdata": "0x0102030405060708090a0b0c0d0e0f100000000000000001", + "request-nonce": "0x0102030405060708090a0b0c", + "id-nonce": "0x0102030405060708090a0b0c0d0e0f10", + "enr-seq": "0x1", + }, + "ephemeral-key": "0x0288ef00023598499cb6c940146d050d2b1fb914198c327f76aad590bead68b6", + "ephemeral-pubkey": "0x039a003ba6517b473fa0cd74aefe99dadfdb34627f90fec6362df85803908f53a5", # noqa: E501 + }, + "encoded": ( + "00000000000000000000000000000000088b3d4342774649305f313964a39e55" + "ea96c005ad521d8c7560413a7008f16c9e6d2f43bbea8814a546b7409ce783d3" + "4c4f53245d08da4bb252012b2cba3f4f374a90a75cff91f142fa9be3e0a5f3ef" + "268ccb9065aeecfd67a999e7fdc137e062b2ec4a0eb92947f0d9a74bfbf44dfb" + "a776b21301f8b65efd5796706adff216ab862a9186875f9494150c4ae06fa4d1" + "f0396c93f215fa4ef524f1eadf5f0f4126b79336671cbcf7a885b1f8bd2a5d83" + "9cf8" + ), + }, + # handshake packet (ping message) (with ENR) + { + "read-key": "0x53b1c075f41876423154e157470c2f48", + "src-node-id": "0xaaaa8419e9f49d0083561b48287df592939a8d19947d8c0ef88f2a4856a69fbb", + "dest-node-id": "0xbbbb9d047f0488c0b5a93c1c3f2d8bafc7c8ff337024a55434a0d0555de64db9", + "nonce": "0xffffffffffffffffffffffff", + "type": "decoding", + "packet": { + "type": "handshake", + "message": {"req-id": "0x00000001", "enr-seq": "0x1"}, + }, + "handshake-inputs": { + "whoareyou": { + "challenge-data": ( + "0x" + "000000000000000000000000000000006469736376350001010102030405060" + "708090a0b0c00180102030405060708090a0b0c0d0e0f100000000000000000" + ), + "request-nonce": "0x0102030405060708090a0b0c", + "id-nonce": "0x0102030405060708090a0b0c0d0e0f10", + "enr-seq": "0x0", + }, + "ephemeral-key": "0x0288ef00023598499cb6c940146d050d2b1fb914198c327f76aad590bead68b6", + "ephemeral-pubkey": "0x039a003ba6517b473fa0cd74aefe99dadfdb34627f90fec6362df85803908f53a5", # noqa: E501 + }, + "encoded": ( + "00000000000000000000000000000000088b3d4342774649305f313964a39e55" + "ea96c005ad539c8c7560413a7008f16c9e6d2f43bbea8814a546b7409ce783d3" + "4c4f53245d08da4bb23698868350aaad22e3ab8dd034f548a1c43cd246be9856" + "2fafa0a1fa86d8e7a3b95ae78cc2b988ded6a5b59eb83ad58097252188b902b2" + "1481e30e5e285f19735796706adff216ab862a9186875f9494150c4ae06fa4d1" + "f0396c93f215fa4ef524e0ed04c3c21e39b1868e1ca8105e585ec17315e755e6" + "cfc4dd6cb7fd8e1a1f55e49b4b5eb024221482105346f3c82b15fdaae36a3bb1" + "2a494683b4a3c7f2ae41306252fed84785e2bbff3b022812d0882f06978df84a" + "80d443972213342d04b9048fc3b1d5fcb1df0f822152eced6da4d3f6df27e70e" + "4539717307a0208cd208d65093ccab5aa596a34d7511401987662d8cf62b1394" + "71" + ), + }, +) + + +@pytest.mark.parametrize("fixture", PACKET_DECODING_FIXTURES) +def test_v51_specification_packet_decoding_fixtures(fixture): + if fixture["type"] == "decoding": + do_encoding_fixture_test(fixture) + else: + raise Exception("Not supported") + + +def do_handshake_packet_fixture_decoding_test(fixture): + source_node_id = decode_hex(fixture["src-node-id"]) + dest_node_id = decode_hex(fixture["dest-node-id"]) + encoded_packet = decode_hex(fixture["encoded"]) + ping_enr_seq = to_int(hexstr=fixture["packet"]["message"]["enr-seq"]) + who_are_you_enr_seq = to_int( + hexstr=fixture["handshake-inputs"]["whoareyou"]["enr-seq"] + ) + + if who_are_you_enr_seq == ping_enr_seq and who_are_you_enr_seq != 0: + should_have_record = False + else: + should_have_record = True + + # ephemeral_private_key = decode_hex(fixture['handshake-inputs']['ephemeral-key']) + ephemeral_public_key = decode_hex(fixture["handshake-inputs"]["ephemeral-pubkey"]) + # ephemeral_private_key = decode_hex(fixture["handshake-inputs"]["ephemeral-key"]) + + # request_nonce = decode_hex(fixture['handshake-inputs']['whoareyou']['request-nonce']) + challenge_data = decode_hex( + fixture["handshake-inputs"]["whoareyou"]["challenge-data"] + ) + masking_iv, static_header, who_are_you = extract_challenge_data(challenge_data) + + id_nonce = decode_hex(fixture["handshake-inputs"]["whoareyou"]["id-nonce"]) + assert who_are_you.id_nonce == id_nonce + + aes_gcm_nonce = decode_hex(fixture["nonce"]) + # TODO: why doesn't this match + # assert static_header.aes_gcm_nonce == aes_gcm_nonce + + signature_inputs = V4HandshakeScheme.signature_inputs_cls( + iv=masking_iv, + header=static_header, + who_are_you=WhoAreYouPacket(id_nonce, who_are_you_enr_seq), + ephemeral_public_key=ephemeral_public_key, + recipient_node_id=dest_node_id, + ) + + id_nonce_signature = V4HandshakeScheme.create_id_nonce_signature( + signature_inputs=signature_inputs, private_key=NODE_KEY_A, + ) + + packet = decode_packet(encoded_packet, dest_node_id) + expected_auth_data = HandshakePacket( + auth_data_head=HandshakeHeader(source_node_id, 64, 33), + id_signature=id_nonce_signature, + ephemeral_public_key=ephemeral_public_key, + record=packet.auth_data.record, + ) + + assert expected_auth_data == packet.auth_data + assert packet.header.aes_gcm_nonce == aes_gcm_nonce + + if should_have_record: + assert packet.auth_data.record is not None + assert packet.auth_data.record.node_id == source_node_id + else: + assert packet.auth_data.record is None + + expected_message = PingMessage( + request_id=decode_hex(fixture["packet"]["message"]["req-id"]), + enr_seq=to_int(hexstr=fixture["packet"]["message"]["enr-seq"]), + ) + actual_message = decode_message( + decryption_key=decode_hex(fixture["read-key"]), + aes_gcm_nonce=aes_gcm_nonce, + message_cipher_text=packet.message_cipher_text, + authenticated_data=packet.challenge_data, + ) + assert expected_message == actual_message + + +def do_encoding_fixture_test(fixture): + if fixture["packet"]["type"] == "whoareyou": + do_who_are_you_packet_fixture_decoding_test(fixture) + elif fixture["packet"]["type"] == "message": + do_message_packet_fixture_decoding_test(fixture) + elif fixture["packet"]["type"] == "handshake": + do_handshake_packet_fixture_decoding_test(fixture) + else: + raise Exception("Not supported") + + +def do_who_are_you_packet_fixture_decoding_test(fixture): + dest_node_id = decode_hex(fixture["dest-node-id"]) + expected_auth_data = WhoAreYouPacket( + id_nonce=decode_hex(fixture["packet"]["id-nonce"]), + enr_sequence_number=to_int(hexstr=fixture["packet"]["enr-seq"]), + ) + encoded_packet = decode_hex(fixture["encoded"]) + aes_gcm_nonce = decode_hex(fixture["packet"]["request-nonce"]) + + packet = decode_packet(encoded_packet, dest_node_id) + + assert packet.auth_data == expected_auth_data + assert packet.header.aes_gcm_nonce == aes_gcm_nonce + + +def do_message_packet_fixture_decoding_test(fixture): + dest_node_id = decode_hex(fixture["dest-node-id"]) + expected_auth_data = MessagePacket( + source_node_id=decode_hex(fixture["src-node-id"]), + ) + expected_message = PingMessage( + request_id=decode_hex(fixture["packet"]["message"]["req-id"]), + enr_seq=to_int(hexstr=fixture["packet"]["message"]["enr-seq"]), + ) + encoded_packet = decode_hex(fixture["encoded"]) + packet = decode_packet(encoded_packet, dest_node_id) + assert packet.auth_data == expected_auth_data + + aes_gcm_nonce = decode_hex(fixture["nonce"]) + + actual_message = decode_message( + decryption_key=decode_hex(fixture["read-key"]), + aes_gcm_nonce=aes_gcm_nonce, + message_cipher_text=packet.message_cipher_text, + authenticated_data=packet.challenge_data, + ) + assert actual_message == expected_message + + +ID_NONCE_SIGNING_FIXTURES = ( + { + "static-key": "0xfb757dc581730490a1d7a00deea65e9b1936924caaea8f44d476014856b68736", + "challenge-data": ( + "0x" + "0000000000000000000000000000000064697363763500010101020304050607" + "08090a0b0c00180102030405060708090a0b0c0d0e0f100000000000000000" + ), + "ephemeral-pubkey": "0x039961e4c2356d61bedb83052c115d311acb3a96f5777296dcf297351130266231", + "node-id-B": "0xbbbb9d047f0488c0b5a93c1c3f2d8bafc7c8ff337024a55434a0d0555de64db9", + "signature": ( + "0x" + "94852a1e2318c4e5e9d422c98eaf19d1d90d876b29cd06ca7cb7546d0fff7b48" + "4fe86c09a064fe72bdbef73ba8e9c34df0cd2b53e9d65528c2c7f336d5dfc6e6" + ), + }, +) + + +def extract_challenge_data(challenge_data): + stream = io.BytesIO(challenge_data) + + masking_iv = stream.read(16) + static_header = Header.from_wire_bytes(stream.read(HEADER_PACKET_SIZE)) + who_are_you = WhoAreYouPacket.from_wire_bytes(stream.read(WHO_ARE_YOU_PACKET_SIZE)) + + assert stream.read() == b"" + + return masking_iv, static_header, who_are_you + + +@pytest.mark.parametrize("fixture", ID_NONCE_SIGNING_FIXTURES) +def test_v51_specification_id_nonce_signing_fixtures(fixture): + private_key = decode_hex(fixture["static-key"]) + public_key = keys.PrivateKey(private_key).public_key.to_compressed_bytes() + + challenge_data = decode_hex(fixture["challenge-data"]) + masking_iv, static_header, who_are_you = extract_challenge_data(challenge_data) + + ephemeral_public_key = decode_hex(fixture["ephemeral-pubkey"]) + recipient_node_id = decode_hex(fixture["node-id-B"]) + expected_signature = decode_hex(fixture["signature"]) + + signature_inputs = V4HandshakeScheme.signature_inputs_cls( + iv=masking_iv, + header=static_header, + who_are_you=who_are_you, + ephemeral_public_key=ephemeral_public_key, + recipient_node_id=recipient_node_id, + ) + actual_signature = V4HandshakeScheme.create_id_nonce_signature( + signature_inputs=signature_inputs, private_key=private_key, + ) + assert actual_signature == expected_signature + + V4HandshakeScheme.validate_id_nonce_signature( + signature_inputs=signature_inputs, + signature=expected_signature, + public_key=public_key, + ) + + +KEY_DERIVATION_FIXTURES = ( + { + "ephemeral-key": "0xfb757dc581730490a1d7a00deea65e9b1936924caaea8f44d476014856b68736", + "dest-pubkey": "0x0317931e6e0840220642f230037d285d122bc59063221ef3226b1f403ddc69ca91", + "node-id-a": "0xaaaa8419e9f49d0083561b48287df592939a8d19947d8c0ef88f2a4856a69fbb", + "node-id-b": "0xbbbb9d047f0488c0b5a93c1c3f2d8bafc7c8ff337024a55434a0d0555de64db9", + "challenge-data": ( + "0x" + "000000000000000000000000000000006469736376350001010102030405060" + "708090a0b0c00180102030405060708090a0b0c0d0e0f100000000000000000" + ), + "initiator-key": "0xdccc82d81bd610f4f76d3ebe97a40571", + "recipient-key": "0xac74bb8773749920b0d3a8881c173ec5", + }, +) + + +@pytest.mark.parametrize("fixture", KEY_DERIVATION_FIXTURES) +def test_v51_specification_key_derivation(fixture): + ephemeral_private_key = decode_hex(fixture["ephemeral-key"]) + dest_public_key = decode_hex(fixture["dest-pubkey"]) + node_id_A = decode_hex(fixture["node-id-a"]) + node_id_B = decode_hex(fixture["node-id-b"]) + challenge_data = decode_hex(fixture["challenge-data"]) + initiator_key = decode_hex(fixture["initiator-key"]) + recipient_key = decode_hex(fixture["recipient-key"]) + + session_keys = V4HandshakeScheme.compute_session_keys( + local_private_key=ephemeral_private_key, + remote_public_key=dest_public_key, + local_node_id=node_id_A, + remote_node_id=node_id_B, + salt=challenge_data, + is_locally_initiated=True, + ) + assert session_keys.encryption_key == initiator_key + assert session_keys.decryption_key == recipient_key