Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

v5.1 Specification Update #92

Merged
merged 2 commits into from
Oct 6, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 12 additions & 13 deletions ddht/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from ddht.base_message import BaseMessage
from ddht.boot_info import BootInfo
from ddht.typing import JSON, IDNonce, SessionKeys
from ddht.typing import JSON, SessionKeys

TAddress = TypeVar("TAddress", bound="AddressAPI")

Expand Down Expand Up @@ -177,8 +177,12 @@ def iter_all_random(self) -> Iterator[NodeID]:
...


class HandshakeSchemeAPI(ABC):
TSignatureInputs = TypeVar("TSignatureInputs")


class HandshakeSchemeAPI(ABC, Generic[TSignatureInputs]):
identity_scheme: Type[IdentitySchemeAPI]
signature_inputs_cls: Type[TSignatureInputs]

#
# Handshake
Expand All @@ -204,7 +208,7 @@ def compute_session_keys(
remote_public_key: bytes,
local_node_id: NodeID,
remote_node_id: NodeID,
id_nonce: IDNonce,
salt: bytes,
is_locally_initiated: bool,
) -> SessionKeys:
"""Compute the symmetric session keys."""
Expand All @@ -213,20 +217,15 @@ def compute_session_keys(
@classmethod
@abstractmethod
def create_id_nonce_signature(
cls, *, id_nonce: IDNonce, ephemeral_public_key: bytes, private_key: bytes
cls, *, signature_inputs: TSignatureInputs, private_key: bytes,
) -> bytes:
"""Sign an id nonce received during handshake."""
...

@classmethod
@abstractmethod
def validate_id_nonce_signature(
cls,
*,
id_nonce: IDNonce,
ephemeral_public_key: bytes,
signature: bytes,
public_key: bytes,
cls, *, signature_inputs: TSignatureInputs, signature: bytes, public_key: bytes,
) -> None:
"""Validate the id nonce signature received from a peer."""
...
Expand All @@ -235,7 +234,7 @@ def validate_id_nonce_signature(
# https://github.com/python/mypy/issues/5264#issuecomment-399407428
if TYPE_CHECKING:
HandshakeSchemeRegistryBaseType = UserDict[
Type[IdentitySchemeAPI], Type[HandshakeSchemeAPI]
Type[IdentitySchemeAPI], Type[HandshakeSchemeAPI[Any]]
]
else:
HandshakeSchemeRegistryBaseType = UserDict
Expand All @@ -244,8 +243,8 @@ def validate_id_nonce_signature(
class HandshakeSchemeRegistryAPI(HandshakeSchemeRegistryBaseType):
@abstractmethod
def register(
self, handshake_scheme_class: Type[HandshakeSchemeAPI]
) -> Type[HandshakeSchemeAPI]:
self, handshake_scheme_class: Type[HandshakeSchemeAPI[TSignatureInputs]]
) -> Type[HandshakeSchemeAPI[TSignatureInputs]]:
...


Expand Down
13 changes: 11 additions & 2 deletions ddht/cli_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,19 @@ def __call__(
network_parser.add_argument(
"--listen-address", type=ipaddress.ip_address, help="IP address to listen on"
)
network_parser.add_argument(

bootnodes_parser_group = network_parser.add_mutually_exclusive_group()
bootnodes_parser_group.add_argument(
"--bootnode",
action=NormalizeAndAppendENR,
help="IP address to listen on",
help="ENR for custom bootnode",
dest="bootnodes",
)
bootnodes_parser_group.add_argument(
"--no-bootstrap",
help="Start without any bootnodes",
action="store_const",
const=(),
dest="bootnodes",
)

Expand Down
1 change: 0 additions & 1 deletion ddht/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
NEIGHBOURS_RESPONSE_ITEMS = 16
AES128_KEY_SIZE = 16 # size of an AES218 key
HKDF_INFO = b"discovery v5 key agreement"
ID_NONCE_SIGNATURE_PREFIX = b"discovery-id-nonce"
ENR_REPR_PREFIX = "enr:" # prefix used when printing an ENR
MAX_ENR_SIZE = 300 # maximum allowed size of an ENR
IP_V4_ADDRESS_ENR_KEY = b"ip"
Expand Down
6 changes: 3 additions & 3 deletions ddht/encryption.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def validate_nonce(nonce: bytes) -> None:


def aesgcm_encrypt(
key: AES128Key, nonce: Nonce, plain_text: bytes, authenticated_data: bytes
key: AES128Key, nonce: Nonce, plain_text: bytes, authenticated_data: bytes,
) -> bytes:
validate_aes128_key(key)
validate_nonce(nonce)
Expand All @@ -35,7 +35,7 @@ def aesgcm_encrypt(


def aesgcm_decrypt(
key: AES128Key, nonce: Nonce, cipher_text: bytes, authenticated_data: bytes
key: AES128Key, nonce: Nonce, cipher_text: bytes, authenticated_data: bytes,
) -> bytes:
validate_aes128_key(key)
validate_nonce(nonce)
Expand All @@ -44,7 +44,7 @@ def aesgcm_decrypt(
try:
plain_text = aesgcm.decrypt(nonce, cipher_text, authenticated_data)
except InvalidTag as error:
raise DecryptionError() from error
raise DecryptionError(str(error)) from error
else:
return plain_text

Expand Down
62 changes: 10 additions & 52 deletions ddht/handshake_schemes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from hashlib import sha256
import secrets
from typing import Tuple, Type

Expand All @@ -13,15 +12,15 @@
from eth_typing import NodeID
from eth_utils import ValidationError, encode_hex

from ddht.abc import HandshakeSchemeAPI, HandshakeSchemeRegistryAPI
from ddht.constants import AES128_KEY_SIZE, HKDF_INFO, ID_NONCE_SIGNATURE_PREFIX
from ddht.typing import AES128Key, IDNonce, SessionKeys
from ddht.abc import HandshakeSchemeAPI, HandshakeSchemeRegistryAPI, TSignatureInputs
from ddht.constants import AES128_KEY_SIZE, HKDF_INFO
from ddht.typing import AES128Key, SessionKeys


class HandshakeSchemeRegistry(HandshakeSchemeRegistryAPI):
def register(
self, handshake_scheme_class: Type[HandshakeSchemeAPI]
) -> Type[HandshakeSchemeAPI]:
self, handshake_scheme_class: Type[HandshakeSchemeAPI[TSignatureInputs]]
) -> Type[HandshakeSchemeAPI[TSignatureInputs]]:
"""Class decorator to register handshake schemes."""
is_missing_identity_scheme = (
not hasattr(handshake_scheme_class, "identity_scheme")
Expand All @@ -42,9 +41,6 @@ def register(
return handshake_scheme_class


default_handshake_scheme_registry = HandshakeSchemeRegistry()


def ecdh_agree(private_key: bytes, public_key: bytes) -> bytes:
"""
Perform the ECDH key agreement.
Expand All @@ -66,17 +62,14 @@ def ecdh_agree(private_key: bytes, public_key: bytes) -> bytes:


def hkdf_expand_and_extract(
secret: bytes,
initiator_node_id: NodeID,
recipient_node_id: NodeID,
id_nonce: IDNonce,
secret: bytes, initiator_node_id: NodeID, recipient_node_id: NodeID, salt: bytes,
) -> Tuple[bytes, bytes, bytes]:
info = b"".join((HKDF_INFO, initiator_node_id, recipient_node_id))

hkdf = HKDF(
algorithm=SHA256(),
length=3 * AES128_KEY_SIZE,
salt=id_nonce,
salt=salt,
info=info,
backend=cryptography_default_backend(),
)
Expand All @@ -94,8 +87,7 @@ def hkdf_expand_and_extract(
return initiator_key, recipient_key, auth_response_key


@default_handshake_scheme_registry.register
class V4HandshakeScheme(HandshakeSchemeAPI):
class BaseV4HandshakeScheme(HandshakeSchemeAPI[TSignatureInputs]):
identity_scheme = V4IdentityScheme

#
Expand All @@ -119,7 +111,7 @@ def compute_session_keys(
remote_public_key: bytes,
local_node_id: NodeID,
remote_node_id: NodeID,
id_nonce: IDNonce,
salt: bytes,
is_locally_initiated: bool,
) -> SessionKeys:
secret = ecdh_agree(local_private_key, remote_public_key)
Expand All @@ -130,7 +122,7 @@ def compute_session_keys(
initiator_node_id, recipient_node_id = remote_node_id, local_node_id

initiator_key, recipient_key, auth_response_key = hkdf_expand_and_extract(
secret, initiator_node_id, recipient_node_id, id_nonce
secret, initiator_node_id, recipient_node_id, salt,
)

if is_locally_initiated:
Expand All @@ -144,33 +136,6 @@ def compute_session_keys(
auth_response_key=AES128Key(auth_response_key),
)

@classmethod
def create_id_nonce_signature(
cls, *, id_nonce: IDNonce, ephemeral_public_key: bytes, private_key: bytes
) -> bytes:
private_key_object = PrivateKey(private_key)
signature_input = cls.create_id_nonce_signature_input(
id_nonce=id_nonce, ephemeral_public_key=ephemeral_public_key
)
signature = private_key_object.sign_msg_hash_non_recoverable(signature_input)
return bytes(signature)

@classmethod
def validate_id_nonce_signature(
cls,
*,
id_nonce: IDNonce,
ephemeral_public_key: bytes,
signature: bytes,
public_key: bytes,
) -> None:
signature_input = cls.create_id_nonce_signature_input(
id_nonce=id_nonce, ephemeral_public_key=ephemeral_public_key
)
cls.identity_scheme.validate_signature(
message_hash=signature_input, signature=signature, public_key=public_key
)

#
# Helpers
#
Expand Down Expand Up @@ -210,10 +175,3 @@ def validate_signature(
f"Signature {encode_hex(signature)} is not valid for message hash "
f"{encode_hex(message_hash)} and public key {encode_hex(public_key)}"
)

@classmethod
def create_id_nonce_signature_input(
cls, *, id_nonce: IDNonce, ephemeral_public_key: bytes
) -> bytes:
preimage = b"".join((ID_NONCE_SIGNATURE_PREFIX, id_nonce, ephemeral_public_key))
return sha256(preimage).digest()
4 changes: 2 additions & 2 deletions ddht/tools/driver/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,11 @@ async def next_message(self) -> AnyInboundMessage:
...

@abstractmethod
async def send_ping(self, request_id: Optional[int] = None) -> PingMessage:
async def send_ping(self, request_id: Optional[bytes] = None) -> PingMessage:
...

@abstractmethod
async def send_pong(self, request_id: Optional[int] = None) -> PongMessage:
async def send_pong(self, request_id: Optional[bytes] = None) -> PongMessage:
...


Expand Down
25 changes: 12 additions & 13 deletions ddht/tools/driver/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import AsyncIterator, List, Optional

from async_generator import asynccontextmanager
from eth_utils import int_to_big_endian
import trio

from ddht.base_message import AnyInboundMessage, AnyOutboundMessage, BaseMessage
Expand Down Expand Up @@ -51,17 +52,17 @@ async def next_message(self) -> AnyInboundMessage:
return await self.channels.inbound_message_receive_channel.receive()

@no_hang
async def send_ping(self, request_id: Optional[int] = None) -> PingMessage:
async def send_ping(self, request_id: Optional[bytes] = None) -> PingMessage:
if request_id is None:
request_id = secrets.randbits(32)
request_id = int_to_big_endian(secrets.randbits(32))
message = PingMessage(request_id, self.node.enr.sequence_number)
await self.send_message(message)
return message

@no_hang
async def send_pong(self, request_id: Optional[int] = None) -> PongMessage:
async def send_pong(self, request_id: Optional[bytes] = None) -> PongMessage:
if request_id is None:
request_id = secrets.randbits(32)
request_id = int_to_big_endian(secrets.randbits(32))
message = PongMessage(
request_id,
self.node.enr.sequence_number,
Expand Down Expand Up @@ -133,22 +134,20 @@ async def transmit(self) -> AsyncIterator[None]:

@no_hang
async def send_packet(self, packet: AnyPacket) -> None:
if packet.header.source_node_id == self.initiator.node.node_id:
await self.recipient.session.handle_inbound_envelope(
if packet.dest_node_id == self.initiator.node.node_id:
await self.initiator.session.handle_inbound_envelope(
InboundEnvelope(
packet=packet, sender_endpoint=self.initiator.node.endpoint,
packet=packet, sender_endpoint=self.recipient.node.endpoint,
)
)
elif packet.header.source_node_id == self.recipient.node.node_id:
await self.initiator.session.handle_inbound_envelope(
elif packet.dest_node_id == self.recipient.node.node_id:
await self.recipient.session.handle_inbound_envelope(
InboundEnvelope(
packet=packet, sender_endpoint=self.recipient.node.endpoint,
packet=packet, sender_endpoint=self.initiator.node.endpoint,
)
)
else:
raise Exception(
f"No matching node-id: {packet.header.source_node_id.hex()}"
)
raise Exception(f"No matching node-id: {packet.dest_node_id.hex()}")

@no_hang
async def handshake(self) -> None:
Expand Down
Loading