From 14877fd959f20a1207fe5e8c09253a4cb93d9904 Mon Sep 17 00:00:00 2001 From: Daira Emma Hopwood Date: Tue, 12 Dec 2023 03:43:03 +0000 Subject: [PATCH] Finish the implementation of Streamlet and add tests. Signed-off-by: Daira Emma Hopwood --- simtfl/bc/chain.py | 8 +- simtfl/bft/chain.py | 107 ++++++++---- simtfl/bft/streamlet/__init__.py | 4 + simtfl/bft/streamlet/chain.py | 6 +- simtfl/bft/streamlet/node.py | 289 +++++++++++++++++++++++++++++-- simtfl/message.py | 3 + simtfl/network.py | 26 ++- simtfl/node.py | 4 +- 8 files changed, 386 insertions(+), 61 deletions(-) diff --git a/simtfl/bc/chain.py b/simtfl/bc/chain.py index a0f1d10..0aca192 100644 --- a/simtfl/bc/chain.py +++ b/simtfl/bc/chain.py @@ -4,7 +4,7 @@ from __future__ import annotations -from typing import Iterable, Optional +from typing import Iterable, Optional, TypeAlias from collections.abc import Sequence from dataclasses import dataclass from enum import Enum, auto @@ -245,13 +245,13 @@ def is_noncontextually_valid(self) -> bool: class BCProtocol: """A best-chain protocol.""" - Transaction: type[object] = BCTransaction + Transaction: TypeAlias = BCTransaction """The type of transactions for this protocol.""" - Context: type[object] = BCContext + Context: TypeAlias = BCContext """The type of contexts for this protocol.""" - Block: type[object] = BCBlock + Block: TypeAlias = BCBlock """The type of blocks for this protocol.""" diff --git a/simtfl/bft/chain.py b/simtfl/bft/chain.py index e491ee7..2b75f96 100644 --- a/simtfl/bft/chain.py +++ b/simtfl/bft/chain.py @@ -26,16 +26,33 @@ def __init__(self, n: int, t: int): Constructs a genesis block for a permissioned BFT protocol with `n` nodes, of which at least `t` must sign each proposal. 
""" + self.n = n + """The number of voters.""" + self.t = t + """The threshold of votes required for notarization.""" + self.parent = None + """The genesis block has no parent (represented as `None`).""" - def last_final(self) -> PermissionedBFTBase: - """ - Returns the last final block in this block's ancestor chain. - For the genesis block, this is itself. - """ - return self + self.length = 1 + """The genesis chain length is 1.""" + + self.last_final = self + """The last final block for the genesis block is itself.""" + + def preceq(self, other: PermissionedBFTBase) -> bool: + """Return True if this block is `other` or an ancestor of `other`.""" + if self.length > other.length: + return False # optimization + return self == other or (other.parent is not None and self.preceq(other.parent)) + + def __eq__(self, other) -> bool: + return isinstance(other, PermissionedBFTBase) and other.parent is None and (self.n, self.t) == (other.n, other.t) + + def __hash__(self) -> int: + return hash((self.n, self.t)) class PermissionedBFTBlock(PermissionedBFTBase): @@ -56,15 +73,24 @@ def __init__(self, proposal: PermissionedBFTProposal): proposal.assert_notarized() self.proposal = proposal + """The proposal for this block.""" + + assert proposal.parent is not None self.parent = proposal.parent + """The parent of this block.""" - def last_final(self): - """ - Returns the last final block in this block's ancestor chain. - This should be overridden by subclasses; the default implementation - will (inefficiently) just return the genesis block. 
- """ - return self if self.parent is None else self.parent.last_final() + self.length = proposal.length + """The chain length of this block.""" + + self.last_final = self.parent.last_final + """The last final block for this block.""" + + def __eq__(self, other) -> bool: + return (isinstance(other, PermissionedBFTBlock) and + (self.n, self.t, self.proposal) == (other.n, other.t, other.proposal)) + + def __hash__(self) -> int: + return hash((self.n, self.t, self.proposal)) class PermissionedBFTProposal(PermissionedBFTBase): @@ -77,8 +103,22 @@ def __init__(self, parent: PermissionedBFTBase): block. """ super().__init__(parent.n, parent.t) + self.parent = parent - self.signers = set() + """The parent block of this proposal.""" + + self.length = parent.length + 1 + """The chain length of this proposal is one greater than its parent block.""" + + self.votes = set() + """The set of voter indices that have voted for this proposal.""" + + def __eq__(self, other): + """Two proposals are equal iff they are the same object.""" + return self is other + + def __hash__(self) -> int: + return id(self) def assert_valid(self) -> None: """ @@ -102,7 +142,7 @@ def assert_notarized(self) -> None: signatures. """ self.assert_valid() - assert len(self.signers) >= self.t + assert len(self.votes) >= self.t def is_notarized(self) -> bool: """Is this proposal notarized?""" @@ -112,14 +152,13 @@ def is_notarized(self) -> bool: except AssertionError: return False - def add_signature(self, index: int) -> None: + def add_vote(self, index: int) -> None: """ - Record that the node with the given `index` has signed this proposal. - If the same node signs more than once, the subsequent signatures are - ignored. + Record that the node with the given `index` has voted for this proposal. + Calls that add the same vote more than once are ignored. 
""" - self.signers.add(index) - assert len(self.signers) <= self.n + self.votes.add(index) + assert len(self.votes) <= self.n __all__ = ['two_thirds_threshold', 'PermissionedBFTBase', 'PermissionedBFTBlock', 'PermissionedBFTProposal'] @@ -132,35 +171,39 @@ def test_basic(self) -> None: # Construct the genesis block. genesis = PermissionedBFTBase(5, 2) current = genesis - self.assertEqual(current.last_final(), genesis) + self.assertEqual(current.last_final, genesis) for _ in range(2): - proposal = PermissionedBFTProposal(current) + parent = current + proposal = PermissionedBFTProposal(parent) proposal.assert_valid() self.assertTrue(proposal.is_valid()) self.assertFalse(proposal.is_notarized()) - # not enough signatures - proposal.add_signature(0) + # not enough votes + proposal.add_vote(0) self.assertFalse(proposal.is_notarized()) - # same index, so we still only have one signature - proposal.add_signature(0) + # same index, so we still only have one vote + proposal.add_vote(0) self.assertFalse(proposal.is_notarized()) - # different index, now we have two signatures as required - proposal.add_signature(1) + # different index, now we have two votes as required + proposal.add_vote(1) proposal.assert_notarized() self.assertTrue(proposal.is_notarized()) current = PermissionedBFTBlock(proposal) - self.assertEqual(current.last_final(), genesis) + self.assertTrue(parent.preceq(current)) + self.assertFalse(current.preceq(parent)) + self.assertNotEqual(current, parent) + self.assertEqual(current.last_final, genesis) def test_assertions(self) -> None: genesis = PermissionedBFTBase(5, 2) proposal = PermissionedBFTProposal(genesis) self.assertRaises(AssertionError, PermissionedBFTBlock, proposal) - proposal.add_signature(0) + proposal.add_vote(0) self.assertRaises(AssertionError, PermissionedBFTBlock, proposal) - proposal.add_signature(1) + proposal.add_vote(1) _ = PermissionedBFTBlock(proposal) diff --git a/simtfl/bft/streamlet/__init__.py b/simtfl/bft/streamlet/__init__.py 
index bbb9dcd..e9d7395 100644 --- a/simtfl/bft/streamlet/__init__.py +++ b/simtfl/bft/streamlet/__init__.py @@ -1,3 +1,7 @@ """ An implementation of adapted-Streamlet ([CS2020] as modified in [Crosslink]). + +[CS2020] https://eprint.iacr.org/2020/088.pdf + +[Crosslink] https://hackmd.io/JqENg--qSmyqRt_RqY7Whw?view """ diff --git a/simtfl/bft/streamlet/chain.py b/simtfl/bft/streamlet/chain.py index 6d9b48b..ac64ed4 100644 --- a/simtfl/bft/streamlet/chain.py +++ b/simtfl/bft/streamlet/chain.py @@ -27,7 +27,7 @@ def __init__(self, parent: StreamletBlock | StreamletGenesis, epoch: int): """The epoch of this proposal.""" def __str__(self) -> str: - return "StreamletProposal(parent=%s, epoch=%s)" % (self.parent, self.epoch) + return f"StreamletProposal(parent={self.parent}, epoch={self.epoch}, length={self.length})" class StreamletGenesis(PermissionedBFTBase): @@ -49,7 +49,7 @@ def __init__(self, n: int): """The last final block of the genesis block is itself.""" def __str__(self) -> str: - return "StreamletGenesis(n=%s)" % (self.n,) + return f"StreamletGenesis(n={self.n})" def proposer_for_epoch(self, epoch: int): assert epoch > 0 @@ -97,4 +97,4 @@ def _compute_last_final(self) -> StreamletBlock | StreamletGenesis: (first, middle, last) = (first.parent, first, middle) def __str__(self) -> str: - return "StreamletBlock(proposal=%s)" % (self.proposal,) + return f"StreamletBlock(proposal={self.proposal})" diff --git a/simtfl/bft/streamlet/node.py b/simtfl/bft/streamlet/node.py index c55f6fe..484500d 100644 --- a/simtfl/bft/streamlet/node.py +++ b/simtfl/bft/streamlet/node.py @@ -4,6 +4,9 @@ from __future__ import annotations +from typing import Optional +from collections.abc import Sequence +from dataclasses import dataclass from ...node import SequentialNode from ...message import Message, PayloadMessage @@ -20,6 +23,35 @@ class Echo(PayloadMessage): pass +@dataclass(frozen=True) +class Ballot(Message): + """ + A ballot message, recording that a voter has voted for a 
`StreamletProposal`. + Ballots should not be forged unless modelling an attack that allows doing so. + """ + proposal: StreamletProposal + """The proposal.""" + voter: int + """The voter.""" + + def __str__(self) -> str: + return f"Ballot({self.proposal}, voter={self.voter})" + + +class Proposal(PayloadMessage): + """ + A message containing a `StreamletProposal`. + """ + pass + + +class Block(PayloadMessage): + """ + A message containing a `StreamletBlock`. + """ + pass + + class StreamletNode(SequentialNode): """ A Streamlet node. @@ -32,7 +64,32 @@ def __init__(self, genesis: StreamletGenesis): """ assert genesis.epoch == 0 self.genesis = genesis + """The genesis block.""" + self.voted_epoch = genesis.epoch + """The last epoch on which this node voted.""" + + self.tip: StreamletBlock | StreamletGenesis = genesis + """ + A longest chain seen by this node. The node's last final block is given by + `self.tip.last_final`. + """ + + self.proposal: Optional[StreamletProposal] = None + """The current proposal by this node, when it is the proposer.""" + + self.safety_violations: set[tuple[StreamletBlock | StreamletGenesis, + StreamletBlock | StreamletGenesis]] = set() + """The set of safety violations detected by this node.""" + + def propose(self, proposal: StreamletProposal) -> ProcessEffect: + """ + (process) Ask the node to make a proposal. + """ + assert proposal.is_valid() + assert proposal.epoch > self.voted_epoch + self.proposal = proposal + return self.broadcast(Proposal(proposal), False) def handle(self, sender: int, message: Message) -> ProcessEffect: """ @@ -42,32 +99,242 @@ def handle(self, sender: int, message: Message) -> ProcessEffect: (This causes the number of messages to blow up by a factor of `n`, but it's what the Streamlet paper specifies and is necessary for its liveness proof.) - * Received non-duplicate proposals may cause us to send a `Vote`. - * ... + * Receiving a non-duplicate `Proposal` may cause us to broadcast a `Ballot`. 
+ * If we are the current proposer, keep track of ballots for our proposal. + * Receiving a `Block` may cause us to update our `tip`. """ if isinstance(message, Echo): message = message.payload else: - yield from self.broadcast(Echo(message)) + yield from self.broadcast(Echo(message), False) - if isinstance(message, StreamletProposal): - yield from self.handle_proposal(message) - elif isinstance(message, StreamletBlock): - yield from self.handle_block(message) + if isinstance(message, Proposal): + yield from self.handle_proposal(message.payload) + elif isinstance(message, Block): + yield from self.handle_block(message.payload) + elif isinstance(message, Ballot): + yield from self.handle_ballot(message) else: yield from super().handle(sender, message) def handle_proposal(self, proposal: StreamletProposal) -> ProcessEffect: """ (process) If we already voted in the epoch specified by the proposal or a - later epoch, ignore this proposal. + later epoch, ignore this proposal. Otherwise, cast a vote for it iff it + is valid. """ if proposal.epoch <= self.voted_epoch: - self.log("handle", + self.log("proposal", f"received proposal for epoch {proposal.epoch} but we already voted in epoch {self.voted_epoch}") return skip() - return skip() + if proposal.is_valid(): + self.log("proposal", f"voting for {proposal}") + # For now we just forget that we made a proposal if we receive a different + # valid one from another node. This is not realistic. Note that we can and + # should vote for our own proposal. + if proposal != self.proposal: + self.proposal = None + + self.voted_epoch = proposal.epoch + return self.broadcast(Ballot(proposal, self.ident), True) + else: + return skip() def handle_block(self, block: StreamletBlock) -> ProcessEffect: - raise NotImplementedError + """ + If `block.last_final` does not descend from `self.tip.last_final`, reject the block. 
+ (In this case, if also `self.tip.last_final` does not descend from `block.last_final`, + this is a detected safety violation.) + + Otherwise, update `self.tip` to `block` iff `block` is later in lexicographic ordering + by `(length, epoch)`. + """ + if not self.tip.last_final.preceq(block.last_final): + self.log("block", f"× not ⪰ last_final: {block}") + if not block.last_final.preceq(self.tip.last_final): + self.log("block", f"! safety violation: ({block}, {self.tip})") + self.safety_violations.add((block, self.tip)) + return skip() + + # TODO: analyse tie-breaking rule. + if (self.tip.length, self.tip.epoch) >= (block.length, block.epoch): + self.log("block", f"× not updating tip: {block}") + return skip() + + self.log("block", f"✓ updating tip: {block}") + self.tip = block + return skip() + + def handle_ballot(self, ballot: Ballot) -> ProcessEffect: + """ + If we have made a proposal that is not yet notarized and the ballot is + for that proposal, add the vote. If it is now notarized, broadcast it + as a block. + """ + proposal = ballot.proposal + if proposal == self.proposal: + self.log("count", f"{ballot.voter} voted for our proposal in epoch {proposal.epoch}") + proposal.add_vote(ballot.voter) + if proposal.is_notarized(): + yield from self.broadcast(Block(StreamletBlock(proposal)), True) + # It's fine to forget that we made the proposal now. + self.proposal = None + + def final_block(self) -> StreamletBlock | StreamletGenesis: + """ + Return the last final block seen by this node. + """ + return self.tip.last_final + + +__all__ = ['Echo', 'Ballot', 'Proposal', 'Block', 'StreamletNode'] + +import unittest +from itertools import count +from simpy import Environment +from simpy.events import Process, Timeout + +from ...network import Network +from ...logging import PrintLogger + + +class TestStreamlet(unittest.TestCase): + def test_simple(self) -> None: + """ + Very simple example. 
+ + 0 --- 1 --- 2 --- 3 + """ + self._test_last_final([0, 1, 2], + [0, 0, 2]) + + def test_figure_1(self) -> None: + """ + Figure 1: Streamlet finalization example (without the invalid 'X' proposal). + + 0 --- 2 --- 5 --- 6 --- 7 + \ + -- 1 --- 3 + + 0 - Genesis + N - Notarized block + + This diagram implies the epoch 6 block is the last-final block in the + context of the epoch 7 block, because it is in the middle of 3 blocks + with consecutive epoch numbers, and 6 is the most recent such block. + + (We don't include the block/proposal with the red X because that's not + what we're testing.) + """ + N = None + self._test_last_final([0, 0, 1, N, 2, 5, 6], + [0, 0, 0, 0, 0, 0, 6]) + + def test_complex(self) -> None: + """ + Safety Violation: due to three simultaneous properties: + + - 6 is `last_final` in the context of 7 + - 9 is `last_final` in the context of 10 + - 9 is not a descendant of 6 + + 0 --- 2 --- 5 --- 6 --- 7 + \ + -- 1 --- 3 --- 8 --- 9 --- 10 + """ + N = None + self._test_last_final([0, 0, 1, N, 2, 5, 6, 3, 8, 9], + [0, 0, 0, 0, 0, 0, 6, 0, 0, 9], + expect_divergence_at_epoch=8, + expect_safety_violations={(10, 7)}) + + def _test_last_final(self, + parent_map: Sequence[Optional[int]], + final_map: Sequence[int], + expect_divergence_at_epoch: Optional[int]=None, + expect_safety_violations: set[tuple[int, int]]=set()) -> None: + """ + This test constructs a tree of proposals with structure determined by + `parent_map`, and asserts `block.last_final` matches the structure + determined by `final_map`. + + parent_map: sequence of parent epoch numbers + final_map: sequence of final epoch numbers + expect_divergence_at_epoch: first epoch at which a block does not become the new tip + expect_safety_violations: safety violation proofs + """ + + assert len(parent_map) == len(final_map) + + # Construct the genesis block. 
+ genesis = StreamletGenesis(3) + network = Network(Environment(), logger=PrintLogger()) + for _ in range(genesis.n): + network.add_node(StreamletNode(genesis)) + + current = genesis + self.assertEqual(current.last_final, genesis) + blocks: list[Optional[StreamletBlock | StreamletGenesis]] = [genesis] + + def run() -> ProcessEffect: + for (epoch, parent_epoch, final_epoch) in zip(count(1), parent_map, final_map): + yield Timeout(network.env, 10) + if parent_epoch is None: + blocks.append(None) + continue + + parent = blocks[parent_epoch] + assert parent is not None + proposer = network.node(genesis.proposer_for_epoch(epoch)) + proposal = StreamletProposal(parent, epoch) + self.assertEqual(proposal.length, parent.length + 1) + proposal.assert_valid() + self.assertFalse(proposal.is_notarized()) + + proposer.propose(proposal) + yield Timeout(network.env, 10) + + # The proposer should have sent the block. + assert proposer.proposal is None + + # Make a fake block `current` from the proposal so that we can append + # it to `blocks` and check its `last_final`. + current = StreamletBlock(proposal) + self.assertEqual(current.length, proposal.length) + self.assertTrue(parent.preceq(current)) + self.assertFalse(current.preceq(parent)) + self.assertEqual(len(blocks), current.epoch) + blocks.append(current) + final_block = blocks[final_epoch] + assert final_block is not None + self.assertEqual(current.last_final, final_block) + + # All nodes' tips should be the same. + tip = network.node(0).tip + for i in range(1, network.num_nodes()): + self.assertEqual(network.node(i).tip, tip) + + # If we try to create a new block on top of a chain that is not the longest, + # the nodes will ignore it. 
+ if epoch == expect_divergence_at_epoch: + self.assertLess(current.length, tip.length) + elif expect_divergence_at_epoch is None or epoch < expect_divergence_at_epoch: + self.assertEqual(current.length, tip.length) + self.assertEqual(tip.epoch, epoch) + self.assertEqual(tip.proposal, proposal) + + for node in network.nodes: + node_final = node.final_block() + self.assertEqual(node_final, final_block, + f"epoch {node_final.epoch} != epoch {final_block.epoch}") + + for node in network.nodes: + self.assertEqual(set(((a.epoch, b.epoch) for (a, b) in node.safety_violations)), + expect_safety_violations) + + network.done = True + + Process(network.env, run()) + network.run_all() + self.assertTrue(network.done) diff --git a/simtfl/message.py b/simtfl/message.py index 003919e..d1226f0 100644 --- a/simtfl/message.py +++ b/simtfl/message.py @@ -22,3 +22,6 @@ class PayloadMessage(Message): """ payload: Any """The payload.""" + + def __str__(self) -> str: + return f"{self.__class__.__name__}({self.payload})" diff --git a/simtfl/network.py b/simtfl/network.py index 0b98eb3..e5f3e88 100644 --- a/simtfl/network.py +++ b/simtfl/network.py @@ -30,7 +30,7 @@ def initialize(self, ident: int, env: Environment, network: Network): self.env = env self.network = network - def __str__(self): + def __str__(self) -> str: return f"{self.__class__.__name__}" def log(self, event: str, detail: str): @@ -47,13 +47,13 @@ def send(self, target: int, message: Message, delay: Optional[Number]=None) -> P """ return self.network.send(self.ident, target, message, delay=delay) - def broadcast(self, message: Message, delay: Optional[Number]=None) -> ProcessEffect: + def broadcast(self, message: Message, include_self: bool, delay: Optional[Number]=None) -> ProcessEffect: """ (process) This method can be overridden to intercept messages being broadcast by this node. The implementation in this class calls `self.network.broadcast` with this node as the sender. 
""" - return self.network.broadcast(self.ident, message, delay=delay) + return self.network.broadcast(self.ident, message, include_self, delay=delay) def receive(self, sender: int, message: Message) -> ProcessEffect: """ @@ -86,8 +86,14 @@ def __init__(self, env: Environment, nodes: Optional[list[Node]]=None, delay: Nu a set of initial nodes, message propagation delay, and logger. """ self.env = env + """The `simpy.Environment`.""" + self.nodes = nodes or [] + """The nodes in this network.""" + self.delay = delay + """The message propagation delay.""" + self._logger = logger logger.header() @@ -166,19 +172,21 @@ def send(self, sender: int, target: int, message: Message, delay: Optional[Numbe # TODO: make it take some time on the sending node. return skip() - def broadcast(self, sender: int, message: Message, delay: Optional[Number]=None) -> ProcessEffect: + def broadcast(self, sender: int, message: Message, include_self: bool, + delay: Optional[Number]=None) -> ProcessEffect: """ - (process) Broadcasts a message to every other node. The message - propagation delay is normally given by `self.delay`, but can be - overridden by the `delay` parameter. + (process) Broadcasts a message to every node (including ourself only when + `include_self` is set). The message propagation delay is normally given by + `self.delay`, but can be overridden by the `delay` parameter. """ if delay is None: delay = self.delay - self.log(sender, "broadcast", f"to * with delay {delay:2d}: {message}") + c = "+" if include_self else "-" + self.log(sender, "broadcast", f"to {c}* with delay {delay:2d}: {message}") # Run `convey` in a new process for each node. for target in range(self.num_nodes()): - if target != sender: + if include_self or target != sender: Process(self.env, self.convey(delay, sender, target, message)) # Broadcasting is currently instantaneous. 
diff --git a/simtfl/node.py b/simtfl/node.py index f921f2c..ea00ec8 100644 --- a/simtfl/node.py +++ b/simtfl/node.py @@ -89,7 +89,7 @@ def run(self) -> ProcessEffect: while True: while len(self._mailbox) > 0: (sender, message) = self._mailbox.popleft() - self.log("handle", f"from {sender:2d}: {message}") + self.log("handle", f"from {sender:2d}: {message}") yield from self.handle(sender, message) # This naive implementation is fine because we have no actual @@ -147,7 +147,7 @@ def run(self) -> ProcessEffect: yield Timeout(self.env, 1) # This message is broadcast at time 4 and received at time 5. - yield from self.broadcast(PayloadMessage(4)) + yield from self.broadcast(PayloadMessage(4), False) class TestFramework(unittest.TestCase):