Skip to content

Commit

Permalink
V5.1 Specification updates
Browse files Browse the repository at this point in the history
  • Loading branch information
pipermerriam committed Oct 5, 2020
1 parent b181822 commit fe71b79
Show file tree
Hide file tree
Showing 29 changed files with 1,342 additions and 427 deletions.
25 changes: 12 additions & 13 deletions ddht/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from ddht.base_message import BaseMessage
from ddht.boot_info import BootInfo
from ddht.typing import JSON, IDNonce, SessionKeys
from ddht.typing import JSON, SessionKeys

TAddress = TypeVar("TAddress", bound="AddressAPI")

Expand Down Expand Up @@ -177,8 +177,12 @@ def iter_all_random(self) -> Iterator[NodeID]:
...


class HandshakeSchemeAPI(ABC):
TSignatureInputs = TypeVar("TSignatureInputs")


class HandshakeSchemeAPI(ABC, Generic[TSignatureInputs]):
identity_scheme: Type[IdentitySchemeAPI]
signature_inputs_cls: Type[TSignatureInputs]

#
# Handshake
Expand All @@ -204,7 +208,7 @@ def compute_session_keys(
remote_public_key: bytes,
local_node_id: NodeID,
remote_node_id: NodeID,
id_nonce: IDNonce,
salt: bytes,
is_locally_initiated: bool,
) -> SessionKeys:
"""Compute the symmetric session keys."""
Expand All @@ -213,20 +217,15 @@ def compute_session_keys(
@classmethod
@abstractmethod
def create_id_nonce_signature(
cls, *, id_nonce: IDNonce, ephemeral_public_key: bytes, private_key: bytes
cls, *, signature_inputs: TSignatureInputs, private_key: bytes,
) -> bytes:
"""Sign an id nonce received during handshake."""
...

@classmethod
@abstractmethod
def validate_id_nonce_signature(
cls,
*,
id_nonce: IDNonce,
ephemeral_public_key: bytes,
signature: bytes,
public_key: bytes,
cls, *, signature_inputs: TSignatureInputs, signature: bytes, public_key: bytes,
) -> None:
"""Validate the id nonce signature received from a peer."""
...
Expand All @@ -235,7 +234,7 @@ def validate_id_nonce_signature(
# https://github.com/python/mypy/issues/5264#issuecomment-399407428
if TYPE_CHECKING:
HandshakeSchemeRegistryBaseType = UserDict[
Type[IdentitySchemeAPI], Type[HandshakeSchemeAPI]
Type[IdentitySchemeAPI], Type[HandshakeSchemeAPI[Any]]
]
else:
HandshakeSchemeRegistryBaseType = UserDict
Expand All @@ -244,8 +243,8 @@ def validate_id_nonce_signature(
class HandshakeSchemeRegistryAPI(HandshakeSchemeRegistryBaseType):
@abstractmethod
def register(
self, handshake_scheme_class: Type[HandshakeSchemeAPI]
) -> Type[HandshakeSchemeAPI]:
self, handshake_scheme_class: Type[HandshakeSchemeAPI[TSignatureInputs]]
) -> Type[HandshakeSchemeAPI[TSignatureInputs]]:
...


Expand Down
1 change: 0 additions & 1 deletion ddht/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
NEIGHBOURS_RESPONSE_ITEMS = 16
AES128_KEY_SIZE = 16 # size of an AES218 key
HKDF_INFO = b"discovery v5 key agreement"
ID_NONCE_SIGNATURE_PREFIX = b"discovery-id-nonce"
ENR_REPR_PREFIX = "enr:" # prefix used when printing an ENR
MAX_ENR_SIZE = 300 # maximum allowed size of an ENR
IP_V4_ADDRESS_ENR_KEY = b"ip"
Expand Down
6 changes: 3 additions & 3 deletions ddht/encryption.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def validate_nonce(nonce: bytes) -> None:


def aesgcm_encrypt(
key: AES128Key, nonce: Nonce, plain_text: bytes, authenticated_data: bytes
key: AES128Key, nonce: Nonce, plain_text: bytes, authenticated_data: bytes,
) -> bytes:
validate_aes128_key(key)
validate_nonce(nonce)
Expand All @@ -35,7 +35,7 @@ def aesgcm_encrypt(


def aesgcm_decrypt(
key: AES128Key, nonce: Nonce, cipher_text: bytes, authenticated_data: bytes
key: AES128Key, nonce: Nonce, cipher_text: bytes, authenticated_data: bytes,
) -> bytes:
validate_aes128_key(key)
validate_nonce(nonce)
Expand All @@ -44,7 +44,7 @@ def aesgcm_decrypt(
try:
plain_text = aesgcm.decrypt(nonce, cipher_text, authenticated_data)
except InvalidTag as error:
raise DecryptionError() from error
raise DecryptionError(str(error)) from error
else:
return plain_text

Expand Down
62 changes: 10 additions & 52 deletions ddht/handshake_schemes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from hashlib import sha256
import secrets
from typing import Tuple, Type

Expand All @@ -13,15 +12,15 @@
from eth_typing import NodeID
from eth_utils import ValidationError, encode_hex

from ddht.abc import HandshakeSchemeAPI, HandshakeSchemeRegistryAPI
from ddht.constants import AES128_KEY_SIZE, HKDF_INFO, ID_NONCE_SIGNATURE_PREFIX
from ddht.typing import AES128Key, IDNonce, SessionKeys
from ddht.abc import HandshakeSchemeAPI, HandshakeSchemeRegistryAPI, TSignatureInputs
from ddht.constants import AES128_KEY_SIZE, HKDF_INFO
from ddht.typing import AES128Key, SessionKeys


class HandshakeSchemeRegistry(HandshakeSchemeRegistryAPI):
def register(
self, handshake_scheme_class: Type[HandshakeSchemeAPI]
) -> Type[HandshakeSchemeAPI]:
self, handshake_scheme_class: Type[HandshakeSchemeAPI[TSignatureInputs]]
) -> Type[HandshakeSchemeAPI[TSignatureInputs]]:
"""Class decorator to register handshake schemes."""
is_missing_identity_scheme = (
not hasattr(handshake_scheme_class, "identity_scheme")
Expand All @@ -42,9 +41,6 @@ def register(
return handshake_scheme_class


default_handshake_scheme_registry = HandshakeSchemeRegistry()


def ecdh_agree(private_key: bytes, public_key: bytes) -> bytes:
"""
Perform the ECDH key agreement.
Expand All @@ -66,17 +62,14 @@ def ecdh_agree(private_key: bytes, public_key: bytes) -> bytes:


def hkdf_expand_and_extract(
secret: bytes,
initiator_node_id: NodeID,
recipient_node_id: NodeID,
id_nonce: IDNonce,
secret: bytes, initiator_node_id: NodeID, recipient_node_id: NodeID, salt: bytes,
) -> Tuple[bytes, bytes, bytes]:
info = b"".join((HKDF_INFO, initiator_node_id, recipient_node_id))

hkdf = HKDF(
algorithm=SHA256(),
length=3 * AES128_KEY_SIZE,
salt=id_nonce,
salt=salt,
info=info,
backend=cryptography_default_backend(),
)
Expand All @@ -94,8 +87,7 @@ def hkdf_expand_and_extract(
return initiator_key, recipient_key, auth_response_key


@default_handshake_scheme_registry.register
class V4HandshakeScheme(HandshakeSchemeAPI):
class BaseV4HandshakeScheme(HandshakeSchemeAPI[TSignatureInputs]):
identity_scheme = V4IdentityScheme

#
Expand All @@ -119,7 +111,7 @@ def compute_session_keys(
remote_public_key: bytes,
local_node_id: NodeID,
remote_node_id: NodeID,
id_nonce: IDNonce,
salt: bytes,
is_locally_initiated: bool,
) -> SessionKeys:
secret = ecdh_agree(local_private_key, remote_public_key)
Expand All @@ -130,7 +122,7 @@ def compute_session_keys(
initiator_node_id, recipient_node_id = remote_node_id, local_node_id

initiator_key, recipient_key, auth_response_key = hkdf_expand_and_extract(
secret, initiator_node_id, recipient_node_id, id_nonce
secret, initiator_node_id, recipient_node_id, salt,
)

if is_locally_initiated:
Expand All @@ -144,33 +136,6 @@ def compute_session_keys(
auth_response_key=AES128Key(auth_response_key),
)

@classmethod
def create_id_nonce_signature(
cls, *, id_nonce: IDNonce, ephemeral_public_key: bytes, private_key: bytes
) -> bytes:
private_key_object = PrivateKey(private_key)
signature_input = cls.create_id_nonce_signature_input(
id_nonce=id_nonce, ephemeral_public_key=ephemeral_public_key
)
signature = private_key_object.sign_msg_hash_non_recoverable(signature_input)
return bytes(signature)

@classmethod
def validate_id_nonce_signature(
cls,
*,
id_nonce: IDNonce,
ephemeral_public_key: bytes,
signature: bytes,
public_key: bytes,
) -> None:
signature_input = cls.create_id_nonce_signature_input(
id_nonce=id_nonce, ephemeral_public_key=ephemeral_public_key
)
cls.identity_scheme.validate_signature(
message_hash=signature_input, signature=signature, public_key=public_key
)

#
# Helpers
#
Expand Down Expand Up @@ -210,10 +175,3 @@ def validate_signature(
f"Signature {encode_hex(signature)} is not valid for message hash "
f"{encode_hex(message_hash)} and public key {encode_hex(public_key)}"
)

@classmethod
def create_id_nonce_signature_input(
cls, *, id_nonce: IDNonce, ephemeral_public_key: bytes
) -> bytes:
preimage = b"".join((ID_NONCE_SIGNATURE_PREFIX, id_nonce, ephemeral_public_key))
return sha256(preimage).digest()
4 changes: 2 additions & 2 deletions ddht/tools/driver/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,11 @@ async def next_message(self) -> AnyInboundMessage:
...

@abstractmethod
async def send_ping(self, request_id: Optional[int] = None) -> PingMessage:
async def send_ping(self, request_id: Optional[bytes] = None) -> PingMessage:
...

@abstractmethod
async def send_pong(self, request_id: Optional[int] = None) -> PongMessage:
async def send_pong(self, request_id: Optional[bytes] = None) -> PongMessage:
...


Expand Down
25 changes: 12 additions & 13 deletions ddht/tools/driver/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import AsyncIterator, List, Optional

from async_generator import asynccontextmanager
from eth_utils import int_to_big_endian
import trio

from ddht.base_message import AnyInboundMessage, AnyOutboundMessage, BaseMessage
Expand Down Expand Up @@ -51,17 +52,17 @@ async def next_message(self) -> AnyInboundMessage:
return await self.channels.inbound_message_receive_channel.receive()

@no_hang
async def send_ping(self, request_id: Optional[int] = None) -> PingMessage:
async def send_ping(self, request_id: Optional[bytes] = None) -> PingMessage:
if request_id is None:
request_id = secrets.randbits(32)
request_id = int_to_big_endian(secrets.randbits(32))
message = PingMessage(request_id, self.node.enr.sequence_number)
await self.send_message(message)
return message

@no_hang
async def send_pong(self, request_id: Optional[int] = None) -> PongMessage:
async def send_pong(self, request_id: Optional[bytes] = None) -> PongMessage:
if request_id is None:
request_id = secrets.randbits(32)
request_id = int_to_big_endian(secrets.randbits(32))
message = PongMessage(
request_id,
self.node.enr.sequence_number,
Expand Down Expand Up @@ -133,22 +134,20 @@ async def transmit(self) -> AsyncIterator[None]:

@no_hang
async def send_packet(self, packet: AnyPacket) -> None:
if packet.header.source_node_id == self.initiator.node.node_id:
await self.recipient.session.handle_inbound_envelope(
if packet.dest_node_id == self.initiator.node.node_id:
await self.initiator.session.handle_inbound_envelope(
InboundEnvelope(
packet=packet, sender_endpoint=self.initiator.node.endpoint,
packet=packet, sender_endpoint=self.recipient.node.endpoint,
)
)
elif packet.header.source_node_id == self.recipient.node.node_id:
await self.initiator.session.handle_inbound_envelope(
elif packet.dest_node_id == self.recipient.node.node_id:
await self.recipient.session.handle_inbound_envelope(
InboundEnvelope(
packet=packet, sender_endpoint=self.recipient.node.endpoint,
packet=packet, sender_endpoint=self.initiator.node.endpoint,
)
)
else:
raise Exception(
f"No matching node-id: {packet.header.source_node_id.hex()}"
)
raise Exception(f"No matching node-id: {packet.dest_node_id.hex()}")

@no_hang
async def handshake(self) -> None:
Expand Down
Loading

0 comments on commit fe71b79

Please sign in to comment.