diff --git a/.daily_canary b/.daily_canary index 00bda69d9d6b..2cd4a1ac6ed1 100644 --- a/.daily_canary +++ b/.daily_canary @@ -1 +1 @@ -Run the daily CI please. \ No newline at end of file +Run the daily CI please! diff --git a/src/consensus/aft/raft.h b/src/consensus/aft/raft.h index f8220b7b28b7..71a522f1b140 100644 --- a/src/consensus/aft/raft.h +++ b/src/consensus/aft/raft.h @@ -943,6 +943,7 @@ namespace aft std::make_unique(*this, from, std::move(r)); break; } + case bft_view_change: { RequestViewChangeMsg r = @@ -1447,10 +1448,12 @@ namespace aft (state->new_view_idx > prev_idx) && (state->new_view_idx <= end_idx); LOG_DEBUG_FMT( - "Send append entries from {} to {}: {} to {} ({})", + "Send append entries from {} to {}: ({}.{}, {}.{}] ({})", state->my_node_id, to, - start_idx, + prev_term, + prev_idx, + term_of_idx, end_idx, state->commit_idx); @@ -2869,10 +2872,6 @@ namespace aft { for (auto it = nodes.begin(); it != nodes.end(); ++it) { - channels->create_channel( - it->first, - it->second.node_info.hostname, - it->second.node_info.port); send_request_vote(it->first); } } @@ -3477,22 +3476,19 @@ namespace aft auto index = state->last_idx + 1; nodes.try_emplace(node_info.first, node_info.second, index, 0); - if ( - replica_state == kv::ReplicaState::Leader || - consensus_type == ConsensusType::BFT) - { - channels->create_channel( - node_info.first, - node_info.second.hostname, - node_info.second.port); - } + channels->associate_node_address( + node_info.first, node_info.second.hostname, node_info.second.port); if (replica_state == kv::ReplicaState::Leader) { send_append_entries(node_info.first, index); } - LOG_INFO_FMT("Added raft node {}", node_info.first); + LOG_INFO_FMT( + "Added raft node {} ({}:{})", + node_info.first, + node_info.second.hostname, + node_info.second.port); } } } diff --git a/src/consensus/aft/test/logging_stub.h b/src/consensus/aft/test/logging_stub.h index 7ac292a5362e..502f0c12fd6f 100644 --- a/src/consensus/aft/test/logging_stub.h +++ b/src/consensus/aft/test/logging_stub.h @@ -162,18 +162,13 @@ namespace aft return std::nullopt; } - void create_channel( + void associate_node_address( const ccf::NodeId& peer_id, const std::string& peer_hostname, - const std::string& peer_service, - size_t message_limit = ccf::Channel::default_message_limit) override + const std::string& peer_service) override {} - void destroy_channel(const ccf::NodeId& peer_id) override {} - - void destroy_all_channels() override {} - - void close_all_outgoing() override {} + void close_channel(const ccf::NodeId& peer_id) override {} void set_endorsed_node_cert(const crypto::Pem&) override {} @@ -202,9 +197,11 @@ namespace aft return true; } - void recv_message( + bool recv_channel_message( const ccf::NodeId& from, const uint8_t* data, size_t size) override - {} + { + return true; + } void initialize( const ccf::NodeId& self_id, diff --git a/src/crypto/symmetric_key.h b/src/crypto/symmetric_key.h index c3fb445deb19..d7329e2f9e35 100644 --- a/src/crypto/symmetric_key.h +++ b/src/crypto/symmetric_key.h @@ -100,14 +100,19 @@ namespace crypto return serial_hdr; } - void deserialise(const std::vector& serial_hdr) + void deserialise(const std::vector& ser) { - auto data_ = serial_hdr.data(); - auto size = serial_hdr.size(); + auto data = ser.data(); + auto size = ser.size(); + + deserialise(data, size); + } + void deserialise(const uint8_t*& data, size_t& size) + { memcpy( - tag, serialized::read(data_, size, GCM_SIZE_TAG).data(), GCM_SIZE_TAG); - memcpy(iv, serialized::read(data_, size, SIZE_IV).data(), SIZE_IV); + tag, serialized::read(data, size, GCM_SIZE_TAG).data(), GCM_SIZE_TAG); + memcpy(iv, serialized::read(data, size, SIZE_IV).data(), SIZE_IV); } }; diff --git a/src/enclave/main.cpp b/src/enclave/main.cpp index cc0f6df7347f..3a2c0de607d1 100644 --- a/src/enclave/main.cpp +++ b/src/enclave/main.cpp @@ -26,6 +26,9 @@ std::atomic num_complete_threads = 0; threading::ThreadMessaging threading::ThreadMessaging::thread_messaging; std::atomic threading::ThreadMessaging::thread_count = 0; +std::chrono::microseconds ccf::Channel::min_gap_between_initiation_attempts( + 2'000'000); + extern "C" { CreateNodeStatus enclave_create_node( diff --git a/src/enclave/rpc_sessions.h b/src/enclave/rpc_sessions.h index 2765793125eb..d8f0acd3fd6c 100644 --- a/src/enclave/rpc_sessions.h +++ b/src/enclave/rpc_sessions.h @@ -219,11 +219,10 @@ namespace enclave { LOG_INFO_FMT( "Refusing TLS session {} inside the enclave - already have {} " - "sessions " - "from interface {} and limit is {}", + "sessions from interface {} and limit is {}", id, - listen_interface_id, per_listen_interface.open_sessions, + listen_interface_id, per_listen_interface.max_open_sessions_hard); RINGBUFFER_WRITE_MESSAGE( @@ -235,11 +234,10 @@ namespace enclave { LOG_INFO_FMT( "Soft refusing session {} (returning 503) inside the enclave - " - "already have {} " - "sessions from interface {} and limit is {}", + "already have {} sessions from interface {} and limit is {}", id, - listen_interface_id, per_listen_interface.open_sessions, + listen_interface_id, per_listen_interface.max_open_sessions_soft); auto ctx = std::make_unique(cert); @@ -256,8 +254,8 @@ namespace enclave { LOG_DEBUG_FMT( "Accepting a session {} inside the enclave from interface {}", - listen_interface_id, - id); + id, + listen_interface_id); auto ctx = std::make_unique(cert); auto session = std::make_shared( diff --git a/src/host/main.cpp b/src/host/main.cpp index b6486ee34fba..ca8aaef3d722 100644 --- a/src/host/main.cpp +++ b/src/host/main.cpp @@ -694,8 +694,7 @@ int main(int argc, char** argv) // This includes DNS resolution and potentially dynamic port assignment (if // requesting port 0). The hostname and port may be modified - after calling // it holds the final assigned values. - asynchost::NodeConnectionsTickingReconnect node( - 20ms, //< Flush reconnections every 20ms + asynchost::NodeConnections node( bp.get_dispatcher(), ledger, writer_factory, diff --git a/src/host/node_connections.h b/src/host/node_connections.h index bd7362ee666b..5934de1f098f 100644 --- a/src/host/node_connections.h +++ b/src/host/node_connections.h @@ -75,7 +75,11 @@ namespace asynchost const size_t payload_size = msg_size.value() - (size_pre_headers - size_post_headers); - associate(from); + if (!node.has_value()) + { + associate_incoming(from); + node = from; + } LOG_DEBUG_FMT( "node in: from node {}, size {}, type {}", @@ -103,48 +107,55 @@ namespace asynchost } } - virtual void associate(const ccf::NodeId&) {} + virtual void associate_incoming(const ccf::NodeId&) {} }; class IncomingBehaviour : public ConnectionBehaviour { public: size_t id; + std::optional node_id; - IncomingBehaviour(NodeConnections& parent, size_t id) : + IncomingBehaviour(NodeConnections& parent, size_t id_) : ConnectionBehaviour(parent), - id(id) + id(id_) {} - void on_disconnect() + void on_disconnect() override { - parent.incoming.erase(id); + LOG_DEBUG_FMT("Disconnecting incoming connection {}", id); + parent.unassociated_incoming.erase(id); - if (node.has_value()) + if (node_id.has_value()) { - LOG_DEBUG_FMT( - "node incoming disconnect {} with node {}", id, node.value()); - parent.associated.erase(node.value()); + parent.remove_connection(node_id.value()); } } - virtual void associate(const ccf::NodeId& n) + void associate_incoming(const ccf::NodeId& n) override { - // It is possible that a peer terminates a connection and opens a new - // one but the termination is seen by us _after_ messages on the new - // connection are received. We re-associate the connection on the latest - // received message, so that oubound messages are routed to the correct, - // most up-to-date, incoming connection. - auto search = parent.associated.find(n); - if (search != parent.associated.end() && search->second.first == id) - { - // Incoming connection is already associated with n - return; - } + node_id = n; + + const auto unassociated = parent.unassociated_incoming.find(id); + CCF_ASSERT_FMT( + unassociated != parent.unassociated_incoming.end(), + "Associating node {} with incoming ID {}, but have already forgotten " + "the incoming connection", + n, + id); + + // Always prefer this (probably) newer connection. Pathological case is + // where both nodes open outgoings to each other at the same time, both + // see the corresponding incoming connections and _drop_ their outgoing + // connections. Both have a useless incoming connection they think they + // can use. Assumption is that they progress at different rates, and one + // of them eventually spots the dead connection and opens a new one + // which succeeds. + parent.connections[n] = unassociated->second; + parent.unassociated_incoming.erase(unassociated); - parent.associated[n] = std::make_pair(id, parent.incoming.at(id)); - LOG_DEBUG_FMT("node incoming {} associated with {}", id, n); - node = n; + LOG_DEBUG_FMT( + "Node incoming connection ({}) associated with {}", id, n); } }; @@ -155,33 +166,36 @@ namespace asynchost ConnectionBehaviour(parent, node) {} - void on_bind_failed() - { - LOG_DEBUG_FMT("node bind failed: {}", node.value()); - reconnect(); - } - - void on_resolve_failed() + void on_bind_failed() override { - LOG_DEBUG_FMT("node resolve failed {}", node.value()); - reconnect(); + LOG_DEBUG_FMT( + "Disconnecting outgoing connection with {}: bind failed", + node.value()); + parent.remove_connection(node.value()); } - void on_connect_failed() + void on_resolve_failed() override { - LOG_DEBUG_FMT("node connect failed {}", node.value()); - reconnect(); + LOG_DEBUG_FMT( + "Disconnecting outgoing connection with {}: resolve failed", + node.value()); + parent.remove_connection(node.value()); } - void on_disconnect() + void on_connect_failed() override { - LOG_DEBUG_FMT("node disconnect failed {}", node.value()); - reconnect(); + LOG_DEBUG_FMT( + "Disconnecting outgoing connection with {}: connect failed", + node.value()); + parent.remove_connection(node.value()); } - void reconnect() + void on_disconnect() override { - parent.request_reconnect(node.value()); + LOG_DEBUG_FMT( + "Disconnecting outgoing connection with {}: disconnected", + node.value()); + parent.remove_connection(node.value()); } }; @@ -202,24 +216,23 @@ namespace asynchost { auto id = parent.get_next_id(); peer->set_behaviour(std::make_unique(parent, id)); - parent.incoming.emplace(id, peer); - - LOG_DEBUG_FMT("node accept {}", id); + parent.unassociated_incoming.emplace(id, peer); + LOG_DEBUG_FMT("Accepted new incoming node connection ({})", id); } }; Ledger& ledger; TCP listener; - // The lifetime of outgoing connections is handled by node channels in the - // enclave - std::unordered_map outgoing; + std::unordered_map> + node_addresses; - std::unordered_map incoming; - std::unordered_map> associated; + std::unordered_map connections; + + std::unordered_map unassociated_incoming; size_t next_id = 1; + ringbuffer::WriterPtr to_enclave; - std::set reconnect_queue; std::optional client_interface = std::nullopt; size_t client_connection_timeout; @@ -250,27 +263,56 @@ namespace asynchost messaging::Dispatcher& disp) { DISPATCHER_SET_MESSAGE_HANDLER( - disp, ccf::add_node, [this](const uint8_t* data, size_t size) { - auto [id, hostname, service] = - ringbuffer::read_message(data, size); - add_node(id, hostname, service); + disp, + ccf::associate_node_address, + [this](const uint8_t* data, size_t size) { + auto [node_id, hostname, service] = + ringbuffer::read_message(data, size); + + node_addresses[node_id] = {hostname, service}; }); DISPATCHER_SET_MESSAGE_HANDLER( - disp, ccf::remove_node, [this](const uint8_t* data, size_t size) { - auto [id] = ringbuffer::read_message(data, size); - remove_node(id); + disp, + ccf::close_node_outbound, + [this](const uint8_t* data, size_t size) { + auto [node_id] = + ringbuffer::read_message(data, size); + + remove_connection(node_id); }); DISPATCHER_SET_MESSAGE_HANDLER( disp, ccf::node_outbound, [this](const uint8_t* data, size_t size) { // Read piece-by-piece rather than all at once ccf::NodeId to = serialized::read(data, size); - auto node = find(to, true); - if (!node) + TCP outbound_connection = nullptr; { - return; + const auto connection_it = connections.find(to); + if (connection_it == connections.end()) + { + const auto address_it = node_addresses.find(to); + if (address_it == node_addresses.end()) + { + LOG_FAIL_FMT("Ignoring node_outbound to unknown node {}", to); + return; + } + + const auto& [host, service] = address_it->second; + outbound_connection = create_connection(to, host, service); + if (outbound_connection.is_null()) + { + LOG_FAIL_FMT( + "Unable to connect to {}, dropping outbound message message", + to); + return; + } + } + else + { + outbound_connection = connection_it->second; + } } // Rather than reading and reserialising, use the msg_type and from_id @@ -300,17 +342,17 @@ namespace asynchost if (framed_entries.has_value()) { frame += (uint32_t)framed_entries->size(); - node.value()->write(sizeof(uint32_t), (uint8_t*)&frame); - node.value()->write(size_to_send, data_to_send); + outbound_connection->write(sizeof(uint32_t), (uint8_t*)&frame); + outbound_connection->write(size_to_send, data_to_send); frame = (uint32_t)framed_entries->size(); - node.value()->write(frame, framed_entries->data()); + outbound_connection->write(frame, framed_entries->data()); } else { // Header-only AE - node.value()->write(sizeof(uint32_t), (uint8_t*)&frame); - node.value()->write(size_to_send, data_to_send); + outbound_connection->write(sizeof(uint32_t), (uint8_t*)&frame); + outbound_connection->write(size_to_send, data_to_send); } LOG_DEBUG_FMT( @@ -327,98 +369,44 @@ namespace asynchost LOG_DEBUG_FMT("node send to {} [{}]", to, frame); - node.value()->write(sizeof(uint32_t), (uint8_t*)&frame); - node.value()->write(size_to_send, data_to_send); + outbound_connection->write(sizeof(uint32_t), (uint8_t*)&frame); + outbound_connection->write(size_to_send, data_to_send); } }); } - void request_reconnect(const ccf::NodeId& node) - { - reconnect_queue.insert(node); - } - - void on_timer() - { - // Swap to local copy of queue. Although this should only be modified by - // this thread, it may be modified recursively (ie - executing this - // function may result in calls to request_reconnect). These recursive - // calls are queued until the next iteration - decltype(reconnect_queue) local_queue; - std::swap(reconnect_queue, local_queue); - - for (const auto& node : local_queue) - { - LOG_DEBUG_FMT("reconnecting node {}", node); - auto s = outgoing.find(node); - if (s != outgoing.end()) - { - s->second->reconnect(); - } - } - } - private: - bool add_node( - const ccf::NodeId& node, + TCP create_connection( + const ccf::NodeId& node_id, const std::string& host, const std::string& service) { - if (outgoing.find(node) != outgoing.end()) - { - LOG_FAIL_FMT("Cannot add node connection {}: already in use", node); - return false; - } - auto s = TCP(true, client_connection_timeout); - s->set_behaviour(std::make_unique(*this, node)); + s->set_behaviour(std::make_unique(*this, node_id)); if (!s->connect(host, service, client_interface)) { - LOG_DEBUG_FMT("Node failed initial connect {}", node); - return false; + LOG_FAIL_FMT( + "Failed to connect to {} on {}:{}", node_id, host, service); + return nullptr; } - outgoing.emplace(node, s); - + connections.emplace(node_id, s); LOG_DEBUG_FMT( - "Added node connection with {} ({}:{})", node, host, service); - return true; - } + "Added node connection with {} ({}:{})", node_id, host, service); - std::optional find(const ccf::NodeId& node, bool use_incoming = false) - { - auto s = outgoing.find(node); - - if (s != outgoing.end()) - { - return s->second; - } - - if (use_incoming) - { - auto s = associated.find(node); - - if (s != associated.end()) - { - return s->second.second; - } - } - - LOG_FAIL_FMT("Unknown node connection {}", node); - return std::nullopt; + return s; } - bool remove_node(const ccf::NodeId& node) + bool remove_connection(const ccf::NodeId& node) { - if (outgoing.erase(node) < 1) + if (connections.erase(node) < 1) { - LOG_FAIL_FMT("Cannot remove node connection {}: does not exist", node); + LOG_DEBUG_FMT("Cannot remove node connection {}: does not exist", node); return false; } - LOG_DEBUG_FMT("Removed outgoing node connection with {}", node); - + LOG_DEBUG_FMT("Removed node connection with {}", node); return true; } @@ -426,7 +414,7 @@ namespace asynchost { auto id = next_id++; - while (incoming.find(id) != incoming.end()) + while (unassociated_incoming.find(id) != unassociated_incoming.end()) { id = next_id++; } @@ -434,6 +422,4 @@ namespace asynchost return id; } }; - - using NodeConnectionsTickingReconnect = proxy_ptr>; } diff --git a/src/node/channels.h b/src/node/channels.h index eb09aab07c0a..01344ffca33b 100644 --- a/src/node/channels.h +++ b/src/node/channels.h @@ -10,6 +10,8 @@ #include "ds/hex.h" #include "ds/logger.h" #include "ds/serialized.h" +#include "ds/state_machine.h" +#include "enclave/enclave_time.h" #include "entities.h" #include "node_types.h" #include "tls/key_exchange.h" @@ -18,6 +20,18 @@ #include #include +// -Wpedantic flags token pasting of __VA_ARGS__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wgnu-zero-variadic-macro-arguments" + +#define CHANNEL_RECV_TRACE(s, ...) \ + LOG_TRACE_FMT("<- {} ({}): " s, peer_id, status.value(), ##__VA_ARGS__) +#define CHANNEL_SEND_TRACE(s, ...) \ + LOG_TRACE_FMT("-> {} ({}): " s, peer_id, status.value(), ##__VA_ARGS__) + +#define CHANNEL_RECV_FAIL(s, ...) \ + LOG_FAIL_FMT("<- {} ({}): " s, peer_id, status.value(), ##__VA_ARGS__) + namespace ccf { using SendNonce = uint64_t; @@ -68,18 +82,19 @@ namespace ccf static constexpr size_t default_message_limit = 23726566; #endif + static std::chrono::microseconds min_gap_between_initiation_attempts; + private: struct OutgoingMsg { NodeMsgType type; - std::vector raw_plain; // To be integrity-protected - std::vector raw_cipher; // To be encrypted + std::vector raw_aad; // To be integrity-protected + std::vector raw_plain; // To be encrypted - OutgoingMsg( - NodeMsgType msg_type, CBuffer raw_plain_, CBuffer raw_cipher_) : + OutgoingMsg(NodeMsgType msg_type, CBuffer raw_aad_, CBuffer raw_plain_) : type(msg_type), - raw_plain(raw_plain_), - raw_cipher(raw_cipher_) + raw_aad(raw_aad_), + raw_plain(raw_plain_) {} }; @@ -90,25 +105,20 @@ namespace ccf crypto::VerifierPtr peer_cv; crypto::Pem peer_cert; - // Notifies the host to create a new outgoing connection ringbuffer::WriterPtr to_host; NodeId peer_id; - std::string peer_hostname; - std::string peer_service; - bool outgoing; // Used for key exchange tls::KeyExchangeContext kex_ctx; - ChannelStatus status = INACTIVE; + ds::StateMachine status; + std::chrono::microseconds last_initiation_time; static constexpr size_t salt_len = 32; static constexpr size_t shared_key_size = 32; std::vector hkdf_salt; - bool key_exchange_in_progress = false; size_t message_limit = default_message_limit; // Used for AES GCM authentication/encryption std::unique_ptr recv_key; - std::unique_ptr next_recv_key; std::unique_ptr send_key; // Incremented for each tagged/encrypted message @@ -134,10 +144,7 @@ namespace ccf CBuffer cipher = nullb, Buffer plain = {}) { - if (status != ESTABLISHED) - { - throw std::logic_error("Channel is not established for verifying"); - } + status.expect(ESTABLISHED); RecvNonce recv_nonce(header.get_iv_int()); auto tid = recv_nonce.tid; @@ -158,30 +165,23 @@ namespace ccf local_nonce = &local_recv_nonce[tid].tid_seqno; } - LOG_TRACE_FMT( - "<- {}: node msg with nonce={}", - peer_id, - (const uint64_t)recv_nonce.nonce); + CHANNEL_RECV_TRACE( + "verify_or_decrypt({} bytes, {} bytes) (nonce={})", + aad.n, + cipher.n, + (size_t)recv_nonce.nonce); // Note: We must assume that some messages are dropped, i.e. we may not - // see every nonce/sequence number, but they must be increasing, except - // during key rollover, when it is reset to 1 for the new key. + // see every nonce/sequence number, but they must be increasing. - if ((recv_nonce.nonce == 1 || !recv_key) && next_recv_key) - { - LOG_TRACE_FMT("Changing to next channel receive key"); - recv_key.swap(next_recv_key); - next_recv_key.reset(); - } - else if (recv_nonce.nonce <= *local_nonce) + if (recv_nonce.nonce <= *local_nonce) { // If the nonce received has already been processed, return // See https://github.com/microsoft/CCF/issues/2492 for more details on // how this can happen around election time - LOG_TRACE_FMT( - "Received past nonce from:{}, received:{}, " + CHANNEL_RECV_TRACE( + "Received past nonce, received:{}, " "last_seen:{}, recv_nonce.tid:{}", - peer_id, reinterpret_cast(recv_nonce.nonce), *local_nonce, recv_nonce.tid); @@ -198,493 +198,547 @@ namespace ccf } size_t num_messages = send_nonce + recv_nonce.nonce; - if (num_messages >= message_limit && !key_exchange_in_progress) + if (num_messages >= message_limit) { - LOG_TRACE_FMT( - "Reached message limit ({}+{}), triggering new key exchange", + CHANNEL_RECV_TRACE( + "Reached message limit ({}+{} >= {}), triggering new key exchange", send_nonce, - (uint64_t)recv_nonce.nonce); + (uint64_t)recv_nonce.nonce, + message_limit); + reset(); initiate(); } return ret; } - public: - static constexpr size_t protocol_version = 1; - - Channel( - ringbuffer::AbstractWriterFactory& writer_factory, - const crypto::Pem& network_cert_, - crypto::KeyPairPtr node_kp_, - const crypto::Pem& node_cert_, - const NodeId& self_, - const NodeId& peer_id_, - const std::string& peer_hostname_, - const std::string& peer_service_, - size_t message_limit_ = default_message_limit) : - self(self_), - network_cert(network_cert_), - node_kp(node_kp_), - node_cert(node_cert_), - to_host(writer_factory.create_writer_to_outside()), - peer_id(peer_id_), - peer_hostname(peer_hostname_), - peer_service(peer_service_), - outgoing(true), - message_limit(message_limit_) + void send_key_exchange_init() { - RINGBUFFER_WRITE_MESSAGE( - ccf::add_node, to_host, peer_id.value(), peer_hostname, peer_service); - auto e = crypto::create_entropy(); - hkdf_salt = e->random(salt_len); - } + std::vector payload; + { + append_msg_type(payload, ChannelMsg::key_exchange_init); + append_protocol_version(payload); + append_vector(payload, kex_ctx.get_own_key_share()); + auto signature = node_kp->sign(kex_ctx.get_own_key_share()); + append_vector(payload, signature); + append_buffer(payload, {node_cert.data(), node_cert.size()}); + append_vector(payload, hkdf_salt); + } - Channel( - ringbuffer::AbstractWriterFactory& writer_factory, - const crypto::Pem& network_cert_, - crypto::KeyPairPtr node_kp_, - const crypto::Pem& node_cert_, - const NodeId& self_, - const NodeId& peer_id_, - size_t message_limit_ = default_message_limit) : - self(self_), - network_cert(network_cert_), - node_kp(node_kp_), - node_cert(node_cert_), - to_host(writer_factory.create_writer_to_outside()), - peer_id(peer_id_), - outgoing(false), - message_limit(message_limit_) - { - auto e = crypto::create_entropy(); - hkdf_salt = e->random(salt_len); + CHANNEL_SEND_TRACE( + "send_key_exchange_init: node serial: {}", + make_verifier(node_cert)->serial_number()); + + RINGBUFFER_WRITE_MESSAGE( + node_outbound, + to_host, + peer_id.value(), + NodeMsgType::channel_msg, + self.value(), + payload); } - ~Channel() + void send_key_exchange_response() { - LOG_INFO_FMT("Channel with {} is now destroyed.", peer_id); + std::vector signature; + { + auto to_sign = kex_ctx.get_own_key_share(); + const auto& peer_ks = kex_ctx.get_peer_key_share(); + to_sign.insert(to_sign.end(), peer_ks.begin(), peer_ks.end()); + signature = node_kp->sign(to_sign); + } - if (outgoing) + std::vector payload; { - RINGBUFFER_WRITE_MESSAGE(ccf::remove_node, to_host, peer_id.value()); + append_msg_type(payload, ChannelMsg::key_exchange_response); + append_protocol_version(payload); + append_vector(payload, kex_ctx.get_own_key_share()); + append_vector(payload, signature); + append_buffer(payload, {node_cert.data(), node_cert.size()}); } - } - void set_status(ChannelStatus status_) - { - status = status_; + CHANNEL_SEND_TRACE( + "send_key_exchange_response: oks={}, serialised_signed_share={}", + ds::to_hex(kex_ctx.get_own_key_share()), + ds::to_hex(payload)); + + RINGBUFFER_WRITE_MESSAGE( + node_outbound, + to_host, + peer_id.value(), + NodeMsgType::channel_msg, + self.value(), + payload); } - ChannelStatus get_status() + void send_key_exchange_final() { - return status; + std::vector payload; + { + append_msg_type(payload, ChannelMsg::key_exchange_final); + // append_protocol_version(payload); // Not sent by current protocol! + auto signature = node_kp->sign(kex_ctx.get_peer_key_share()); + append_vector(payload, signature); + } + + CHANNEL_SEND_TRACE( + "key_exchange_final: ks={}, serialised_signed_key_share={}", + ds::to_hex(kex_ctx.get_peer_key_share()), + ds::to_hex(payload)); + + RINGBUFFER_WRITE_MESSAGE( + node_outbound, + to_host, + peer_id.value(), + NodeMsgType::channel_msg, + self.value(), + payload); } - bool is_outgoing() const + void advance_connection_attempt() { - return outgoing; + if (status.check(INACTIVE)) + { + // We have no key and believe no key exchange is in process - start a + // new iteration of the key exchange protocol + initiate(); + } + else if (status.check(INITIATED)) + { + const auto time_since_initiated = + enclave::get_enclave_time() - last_initiation_time; + if (time_since_initiated >= min_gap_between_initiation_attempts) + { + // If this node attempts to initiate too early when the peer node + // starts up, they will never receive the init message (they drop it + // if it arrives too early in their state machine). The same state + // could also occur later, if the initiate message is lost in transit. + // So sometimes this node needs to re-initiate. However, if this node + // sends too fast before the channel is established, and each send + // generates a new handshake, it may constantly generate new handshake + // attempts and never succeed. Additionally, when talking to peers + // using the old channel behaviour, this node should try to avoid + // confusing them by sending multiple adjacent initiate requests - + // they will only process the first one they receive. To avoid these + // problems with initiation spam, we have a minimum delay between + // initiation attempts. This should be low enough to get reasonable + // liveness (re-attempt connections in the presence of dropped + // messages), but high enough to give successful roundtrips a chance + // to complete. + initiate(); + } + else + { + LOG_INFO_FMT( + "Ignoring advance attempt! Only {} us have elapsed", + time_since_initiated.count()); + } + } } - void set_outgoing( - const std::string& peer_hostname_, const std::string& peer_service_) + bool recv_key_exchange_init( + const uint8_t* data, size_t size, bool they_have_priority = false) { - peer_hostname = peer_hostname_; - peer_service = peer_service_; + CHANNEL_RECV_TRACE( + "recv_key_exchange_init({} bytes, {})", size, they_have_priority); - if (!outgoing) + // Parse fields from incoming message + size_t peer_version = serialized::read(data, size); + if (peer_version != protocol_version) { - RINGBUFFER_WRITE_MESSAGE( - ccf::add_node, to_host, peer_id.value(), peer_hostname, peer_service); + CHANNEL_RECV_FAIL( + "Protocol version mismatch (node={}, peer={})", + protocol_version, + peer_version); + return false; } - outgoing = true; - } - void reset_outgoing() - { - if (outgoing) + CBuffer ks = extract_buffer(data, size); + if (ks.n == 0) { - RINGBUFFER_WRITE_MESSAGE(ccf::remove_node, to_host, peer_id.value()); + CHANNEL_RECV_FAIL("Empty keyshare"); + return false; } - outgoing = false; - } - void sign_key_share( - std::vector& target, - const std::vector& ks, - bool with_salt = false, - const CBuffer extra = {}) - { - auto to_sign = ks; - to_sign.insert(to_sign.end(), extra.p, extra.p + extra.n); + CBuffer sig = extract_buffer(data, size); + if (sig.n == 0) + { + CHANNEL_RECV_FAIL("Empty signature"); + return false; + } - auto signature = node_kp->sign(to_sign); + CBuffer pc = extract_buffer(data, size); + if (pc.n == 0) + { + CHANNEL_RECV_FAIL("Empty cert"); + return false; + } - // Serialise channel key share, signature, and certificate and - // length-prefix them - auto space = - ks.size() + signature.size() + node_cert.size() + 4 * sizeof(size_t); - if (with_salt) + CBuffer salt = extract_buffer(data, size); + if (salt.n == 0) { - space += hkdf_salt.size() + sizeof(size_t); + CHANNEL_RECV_FAIL("Empty salt"); + return false; } - const auto size_before = target.size(); - target.resize(size_before + space); - auto data_ = target.data() + size_before; - serialized::write(data_, space, protocol_version); - serialized::write(data_, space, ks.size()); - serialized::write(data_, space, ks.data(), ks.size()); - serialized::write(data_, space, signature.size()); - serialized::write(data_, space, signature.data(), signature.size()); - serialized::write(data_, space, node_cert.size()); - serialized::write(data_, space, node_cert.data(), node_cert.size()); - if (with_salt) + + if (size != 0) { - serialized::write(data_, space, hkdf_salt.size()); - serialized::write(data_, space, hkdf_salt.data(), hkdf_salt.size()); + CHANNEL_RECV_FAIL("{} exccess bytes remaining", size); + return false; } - } - CBuffer extract_buffer(const uint8_t*& data, size_t& size) const - { - if (size == 0) + // Validate cert and signature in message + crypto::Pem cert; + crypto::VerifierPtr verifier; + if (!verify_peer_certificate(pc, cert, verifier)) { - return {}; + CHANNEL_RECV_FAIL("Peer certificate verification failed"); + return false; } - auto sz = serialized::read(data, size); - CBuffer r(data, sz); + if (!verify_peer_signature(ks, sig, verifier)) + { + return false; + } - if (r.n > size) + // Both nodes tried to initiate the channel, the one with priority + // wins. + if (status.check(INITIATED) && !they_have_priority) { - LOG_FAIL_FMT( - "Buffer header wants {} bytes, but only {} remain", r.n, size); - r.n = 0; + CHANNEL_RECV_TRACE("Ignoring lower priority key init"); + return true; } else { - data += r.n; - size -= r.n; + // Whatever else we _were_ doing, we've received a valid init from them + // - reset to use it + kex_ctx.reset(); + peer_cert = cert; + peer_cv = verifier; } - return r; - } + CHANNEL_RECV_TRACE( + "recv_key_exchange_init: version={} ks={} sig={} pc={} salt={}", + peer_version, + ds::to_hex(ks), + ds::to_hex(sig), + ds::to_hex(pc), + ds::to_hex(salt)); - bool verify_peer_certificate(CBuffer pc) - { - if (pc.n != 0) - { - peer_cert = crypto::Pem(pc); - peer_cv = crypto::make_verifier(peer_cert); + hkdf_salt = {salt.p, salt.p + salt.n}; - if (!peer_cv->verify_certificate({&network_cert})) - { - LOG_FAIL_FMT("Peer certificate verification failed"); - reset(); - return false; - } + kex_ctx.load_peer_key_share(ks); - LOG_TRACE_FMT( - "New peer certificate: {}\n{}", - peer_cv->serial_number(), - peer_cert.str()); - } + status.advance(WAITING_FOR_FINAL); + + // We are the responder and we return a signature over both public key + // shares back to the initiator + send_key_exchange_response(); return true; } - bool verify_peer_signature(CBuffer msg, CBuffer sig) + bool recv_key_exchange_response(const uint8_t* data, size_t size) { - LOG_TRACE_FMT( - "Verifying peer signature with peer certificate serial {}", - peer_cv ? peer_cv->serial_number() : "no peer_cv!"); + CHANNEL_RECV_TRACE("recv_key_exchange_response({} bytes)", size); - if (!peer_cv || !peer_cv->verify(msg, sig)) + if (status.value() != INITIATED) { - LOG_FAIL_FMT( - "Node channel peer signature verification failed for {} with " - "certificate serial {}", - peer_id, - peer_cv->serial_number()); + CHANNEL_RECV_FAIL("Ignoring key exchange response - not expecting it"); return false; } - return true; - } - - // Protocol overview: - // - // initiate() - // > key_exchange_init message - // consume_initiator_key_share() [by responder] - // < key_exchange_response message - // consume_responder_key_share() [by initiator] - // > key_exchange_final message - // check_peer_key_share_signature() [by responder] - // both reach status == ESTABLISHED - - bool consume_responder_key_share(const std::vector& data) - { - return consume_responder_key_share(data.data(), data.size()); - } - - bool consume_responder_key_share(const uint8_t* data, size_t size) - { - LOG_TRACE_FMT("status == {}", status); - - if (status != INITIATED && status != ESTABLISHED) + // Parse fields from incoming message + size_t peer_version = serialized::read(data, size); + if (peer_version != protocol_version) { + CHANNEL_RECV_FAIL( + "Protocol version mismatch (node={}, peer={})", + protocol_version, + peer_version); return false; } - size_t peer_version = serialized::read(data, size); CBuffer ks = extract_buffer(data, size); - CBuffer sig = extract_buffer(data, size); - CBuffer pc = extract_buffer(data, size); - - LOG_TRACE_FMT( - "From responder {}: version={} ks={} sig={} pc={}", - peer_id, - peer_version, - ds::to_hex(ks), - ds::to_hex(sig), - ds::to_hex(pc)); - - if (size != 0) + if (ks.n == 0) { - LOG_FAIL_FMT("{} exccess bytes remaining", size); + CHANNEL_RECV_FAIL("Empty keyshare"); return false; } - if (peer_version != protocol_version) + CBuffer sig = extract_buffer(data, size); + if (sig.n == 0) { - LOG_FAIL_FMT( - "Protocol version mismatch (node={}, peer={})", - protocol_version, - peer_version); + CHANNEL_RECV_FAIL("Empty signature"); return false; } - if (ks.n == 0 || sig.n == 0) + CBuffer pc = extract_buffer(data, size); + if (pc.n == 0) { + CHANNEL_RECV_FAIL("Empty cert"); return false; } - if (!verify_peer_certificate(pc)) + if (size != 0) { + CHANNEL_RECV_FAIL("{} exccess bytes remaining", size); return false; } - // We are the initiator and expect a signature over both key shares - std::vector t = {ks.p, ks.p + ks.n}; - auto oks = kex_ctx.get_own_key_share(); - t.insert(t.end(), oks.begin(), oks.end()); - - if (!verify_peer_signature(t, sig)) + // Validate cert and signature in message + crypto::Pem cert; + crypto::VerifierPtr verifier; + if (!verify_peer_certificate(pc, cert, verifier)) { + CHANNEL_RECV_FAIL("Peer certificate verification failed"); return false; } - kex_ctx.load_peer_key_share(ks); + { + // We are the initiator and expect a signature over both key shares + std::vector signed_msg = {ks.p, ks.p + ks.n}; + const auto& oks = kex_ctx.get_own_key_share(); + signed_msg.insert(signed_msg.end(), oks.begin(), oks.end()); - // Sign the peer's key share - auto signature = node_kp->sign(ks); + if (!verify_peer_signature(signed_msg, sig, verifier)) + { + // This isn't a valid signature for this key exchange attempt. + CHANNEL_RECV_FAIL("Peer certificate verification failed"); + return false; + } + } - // Serialise signature with ChannelMsg- and length- prefixes - auto space = signature.size() + 2 * sizeof(size_t); - std::vector payload(space); - auto data_ = payload.data(); - serialized::write(data_, space, ChannelMsg::key_exchange_final); - serialized::write(data_, space, signature.size()); - serialized::write(data_, space, signature.data(), signature.size()); + peer_cert = cert; + peer_cv = verifier; - RINGBUFFER_WRITE_MESSAGE( - node_outbound, - to_host, - peer_id.value(), - NodeMsgType::channel_msg, - self.value(), - payload); + kex_ctx.load_peer_key_share(ks); - LOG_TRACE_FMT( - "key_exchange_final -> {}: ks={} payload={}", - peer_id, - ds::to_hex(ks), - ds::to_hex(payload)); + send_key_exchange_final(); establish(); return true; } - bool consume_initiator_key_share( - const std::vector& data, bool priority = false) - { - return consume_initiator_key_share(data.data(), data.size(), priority); - } - - bool consume_initiator_key_share( - const uint8_t* data, size_t size, bool priority = false) + bool recv_key_exchange_final(const uint8_t* data, size_t size) { - LOG_TRACE_FMT("status == {}", status); + CHANNEL_RECV_TRACE("recv_key_exchange_final({} bytes)", size); - if (status == INITIATED || status == ESTABLISHED) - { - // Both nodes tried to initiate the channel, the one with priority wins. - if (!priority) - return true; - } - else if (status == WAITING_FOR_FINAL) + if (status.value() != WAITING_FOR_FINAL) { + CHANNEL_RECV_FAIL("Ignoring key exchange final - not expecting it"); return false; } - key_exchange_in_progress = true; + // Parse fields from incoming message + // size_t peer_version = serialized::read(data, size); + // if (peer_version != protocol_version) + // { + // CHANNEL_RECV_FAIL( + // "Protocol version mismatch (node={}, peer={})", + // protocol_version, + // peer_version); + // return false; + // } - size_t peer_version = serialized::read(data, size); - CBuffer peer_ks = extract_buffer(data, size); CBuffer sig = extract_buffer(data, size); - CBuffer pc = extract_buffer(data, size); - CBuffer salt = extract_buffer(data, size); - - LOG_TRACE_FMT( - "From initiator {}: version={} ks={} sig={} pc={} salt={}", - peer_id, - peer_version, - ds::to_hex(peer_ks), - ds::to_hex(sig), - ds::to_hex(pc), - ds::to_hex(salt)); - - if (size != 0) - { - LOG_FAIL_FMT("{} exccess bytes remaining", size); - return false; - } - - hkdf_salt = {salt.p, salt.p + salt.n}; - - if (peer_version != protocol_version) - { - LOG_FAIL_FMT( - "Protocol version mismatch (node={}, peer={})", - protocol_version, - peer_version); - return false; - } - - if (peer_ks.n == 0 || sig.n == 0) + if (sig.n == 0) { + CHANNEL_RECV_FAIL("Empty signature"); return false; } - if (!verify_peer_certificate(pc) || !verify_peer_signature(peer_ks, sig)) + if (!verify_peer_signature(kex_ctx.get_own_key_share(), sig, peer_cv)) { + CHANNEL_RECV_FAIL("Peer certificate verification failed"); return false; } - if (status == ESTABLISHED) - { - // key_ctx does not hold a key share; we need a new one. - kex_ctx.reset(); - } - - kex_ctx.load_peer_key_share(peer_ks); + establish(); - if (status != ESTABLISHED) - status = WAITING_FOR_FINAL; + return true; + } - // We are the responder and we return a signature over both public key - // shares back to the initiator + void append_protocol_version(std::vector& target) + { + const auto size_before = target.size(); + auto size = sizeof(protocol_version); + target.resize(size_before + size); + auto data = target.data() + size_before; + serialized::write(data, size, protocol_version); + } - auto space = 1 * sizeof(size_t); - std::vector payload(space); - auto data_ = payload.data(); - serialized::write(data_, space, ChannelMsg::key_exchange_response); + void append_msg_type(std::vector& target, ChannelMsg msg_type) + { + const auto size_before = target.size(); + auto size = sizeof(msg_type); + target.resize(size_before + size); + auto data = target.data() + size_before; + serialized::write(data, size, msg_type); + } - sign_key_share(payload, kex_ctx.get_own_key_share(), false, peer_ks); + void append_buffer(std::vector& target, CBuffer src) + { + const auto size_before = target.size(); + auto size = src.n + sizeof(src.n); + target.resize(size_before + size); + auto data = target.data() + size_before; + serialized::write(data, size, src.n); + serialized::write(data, size, src.p, src.n); + } - RINGBUFFER_WRITE_MESSAGE( - node_outbound, - to_host, - peer_id.value(), - NodeMsgType::channel_msg, - self.value(), - payload); + void append_vector( + std::vector& target, const std::vector& src) + { + append_buffer(target, src); + } - LOG_TRACE_FMT( - "key_exchange_response -> {}: oks={} payload={}", - peer_id, - ds::to_hex(kex_ctx.get_own_key_share()), - ds::to_hex(payload)); + public: + static constexpr size_t protocol_version = 1; - return true; + Channel( + ringbuffer::AbstractWriterFactory& writer_factory, + const crypto::Pem& network_cert_, + crypto::KeyPairPtr node_kp_, + const crypto::Pem& node_cert_, + const NodeId& self_, + const NodeId& peer_id_, + size_t message_limit_ = default_message_limit) : + self(self_), + network_cert(network_cert_), + node_kp(node_kp_), + node_cert(node_cert_), + to_host(writer_factory.create_writer_to_outside()), + peer_id(peer_id_), + status(fmt::format("Channel to {}", peer_id_), INACTIVE), + message_limit(message_limit_) + { + auto e = crypto::create_entropy(); + hkdf_salt = e->random(salt_len); } - bool check_peer_key_share_signature(const std::vector& data) + ChannelStatus get_status() { - return check_peer_key_share_signature(data.data(), data.size()); + return status.value(); } - bool check_peer_key_share_signature(const uint8_t* data, size_t size) + CBuffer extract_buffer(const uint8_t*& data, size_t& size) const { - LOG_TRACE_FMT("status == {}", status); + if (size == 0) + { + return {}; + } + + auto sz = serialized::read(data, size); + CBuffer r(data, sz); - if (status != WAITING_FOR_FINAL && status != ESTABLISHED) + if (r.n > size) { - return false; + CHANNEL_RECV_FAIL( + "Buffer header wants {} bytes, but only {} remain", r.n, size); + r.n = 0; + } + else + { + data += r.n; + size -= r.n; } - auto oks = kex_ctx.get_own_key_share(); + return r; + } - CBuffer sig = extract_buffer(data, size); + bool verify_peer_certificate( + CBuffer pc, crypto::Pem& cert, crypto::VerifierPtr& verifier) + { + if (pc.n != 0) + { + cert = crypto::Pem(pc); + verifier = crypto::make_verifier(cert); + + if (!verifier->verify_certificate({&network_cert})) + { + return false; + } - if (!verify_peer_signature(oks, sig)) + CHANNEL_RECV_TRACE( + "New peer certificate: {}\n{}", + verifier->serial_number(), + cert.str()); + + return true; + } + else + { return false; + } + } - establish(); + bool verify_peer_signature( + CBuffer msg, CBuffer sig, crypto::VerifierPtr verifier) + { + CHANNEL_RECV_TRACE( + "Verifying peer signature with peer certificate serial {}", + verifier ? verifier->serial_number() : "no peer_cv!"); + + if (!verifier || !verifier->verify(msg, sig)) + { + return false; + } return true; } + // Protocol overview: + // + // initiate() + // > key_exchange_init message + // recv_key_exchange_init() [by responder] + // < key_exchange_response message + // recv_key_exchange_response() [by initiator] + // > key_exchange_final message + // recv_key_exchange_final() [by responder] + // both reach status == ESTABLISHED + void establish() { auto shared_secret = kex_ctx.compute_shared_secret(); - std::string label_to = self.value() + peer_id.value(); - std::string label_from = peer_id.value() + self.value(); - - std::vector info = { - label_from.data(), label_from.data() + label_from.size()}; - auto key_bytes = crypto::hkdf( - crypto::MDType::SHA256, - shared_key_size, - shared_secret, - hkdf_salt, - info); - next_recv_key = crypto::make_key_aes_gcm(key_bytes); - - info = {label_to.data(), label_to.data() + label_to.size()}; - key_bytes = crypto::hkdf( - crypto::MDType::SHA256, - shared_key_size, - shared_secret, - hkdf_salt, - info); - send_key = crypto::make_key_aes_gcm(key_bytes); - kex_ctx.free_ctx(); + + { + const std::string label_from = peer_id.value() + self.value(); + const auto key_bytes = crypto::hkdf( + crypto::MDType::SHA256, + shared_key_size, + shared_secret, + hkdf_salt, + {label_from.begin(), label_from.end()}); + recv_key = crypto::make_key_aes_gcm(key_bytes); + } + + { + const std::string label_to = self.value() + peer_id.value(); + const auto key_bytes = crypto::hkdf( + crypto::MDType::SHA256, + shared_key_size, + shared_secret, + hkdf_salt, + {label_to.begin(), label_to.end()}); + send_key = crypto::make_key_aes_gcm(key_bytes); + } + send_nonce = 1; for (size_t i = 0; i < local_recv_nonce.size(); i++) { local_recv_nonce[i].main_thread_seqno = 0; local_recv_nonce[i].tid_seqno = 0; } - status = ESTABLISHED; - key_exchange_in_progress = false; + + status.advance(ESTABLISHED); LOG_INFO_FMT("Node channel with {} is now established.", peer_id); auto node_cv = make_verifier(node_cert); - LOG_TRACE_FMT( + CHANNEL_RECV_TRACE( "Node certificate serial numbers: node={} peer={}", node_cv->serial_number(), peer_cv->serial_number()); @@ -692,78 +746,66 @@ namespace ccf if (outgoing_msg.has_value()) { send( - outgoing_msg->type, - outgoing_msg->raw_plain, - outgoing_msg->raw_cipher); + outgoing_msg->type, outgoing_msg->raw_aad, outgoing_msg->raw_plain); outgoing_msg.reset(); } } void initiate() { - if (status == WAITING_FOR_FINAL) - return; - LOG_INFO_FMT("Initiating node channel with {}.", peer_id); - key_exchange_in_progress = true; - - if (status != ESTABLISHED) - { - status = INITIATED; - } - else - { - // Restart with new key exchange - kex_ctx.reset(); - peer_cert = {}; - peer_cv.reset(); - - auto e = crypto::create_entropy(); - hkdf_salt = e->random(salt_len); - } + // Begin with new key exchange + kex_ctx.reset(); + peer_cert = {}; + peer_cv.reset(); - auto space = 1 * sizeof(size_t); - std::vector payload(space); - auto data_ = payload.data(); - serialized::write(data_, space, ChannelMsg::key_exchange_init); + auto e = crypto::create_entropy(); + hkdf_salt = e->random(salt_len); - sign_key_share(payload, kex_ctx.get_own_key_share(), true); + // As a future simplification, we would like this to always be true + // (initiations must travel through reset/inactive), but it is not + // currently true + // status.expect(INACTIVE); + status.advance(INITIATED); - RINGBUFFER_WRITE_MESSAGE( - node_outbound, - to_host, - peer_id.value(), - NodeMsgType::channel_msg, - self.value(), - payload); + last_initiation_time = enclave::get_enclave_time(); - auto sn = make_verifier(node_cert)->serial_number(); - LOG_TRACE_FMT("key_exchange_init -> {} node serial: {}", peer_id, sn); + send_key_exchange_init(); } bool send(NodeMsgType type, CBuffer aad, CBuffer plain = nullb) { - if (status != ESTABLISHED) + if (!status.check(ESTABLISHED)) { - initiate(); + advance_connection_attempt(); + if (outgoing_msg.has_value()) + { + LOG_DEBUG_FMT( + "Dropping outgoing message of type {} - replaced by new outgoing " + "send of type {}", + outgoing_msg->type, + type); + } outgoing_msg = OutgoingMsg(type, aad, plain); return false; } - assert(send_key); + RecvNonce nonce( + send_nonce.fetch_add(1), threading::get_current_thread_id()); - // During key rollover, we keep recv_key to decrypt messages from the peer - // until it has rolled over too (recognized when we received a message - // with the nonce/seqno reset to 1). But, we can immediately start to send - // messages with the new send_key. + CHANNEL_SEND_TRACE( + "send({}, {} bytes, {} bytes) (nonce={})", + (size_t)type, + aad.n, + plain.n, + (size_t)nonce.nonce); GcmHdr gcm_hdr; - RecvNonce nonce( - send_nonce.fetch_add(1), threading::get_current_thread_id()); gcm_hdr.set_iv_seq(nonce.get_val()); std::vector cipher(plain.n); + assert(send_key); send_key->encrypt( gcm_hdr.get_iv(), plain, aad, cipher.data(), gcm_hdr.tag); @@ -781,9 +823,6 @@ namespace ccf RINGBUFFER_WRITE_MESSAGE( node_outbound, to_host, peer_id.value(), type, self.value(), payload); - LOG_TRACE_FMT( - "-> {}: node msg with nonce={}", peer_id, (uint64_t)nonce.nonce); - return true; } @@ -791,20 +830,23 @@ namespace ccf { // Receive authenticated message, modifying data to point to the start of // the non-authenticated plaintext payload - if (status != ESTABLISHED) + if (!status.check(ESTABLISHED)) { LOG_INFO_FMT( "Node channel with {} cannot receive authenticated message: not " "established, status={}", peer_id, - status); + status.value()); + advance_connection_attempt(); return false; } - const auto& hdr = serialized::overlay(data, size); + GcmHdr hdr; + hdr.deserialise(data, size); + if (!verify_or_decrypt(hdr, aad)) { - LOG_FAIL_FMT("Failed to verify node message from {}", peer_id); + CHANNEL_RECV_FAIL("Failed to verify node"); return false; } @@ -814,16 +856,17 @@ namespace ccf bool recv_authenticated_with_load(const uint8_t*& data, size_t& size) { // Receive authenticated message, modifying data to point to the start of - // the non-authenticated plaintex payload. data contains payload first, + // the non-authenticated plaintext payload. data contains payload first, // then GCM header - if (status != ESTABLISHED) + if (!status.check(ESTABLISHED)) { LOG_INFO_FMT( "node channel with {} cannot receive authenticated with payload " "message: not established, status={}", peer_id, - status); + status.value()); + advance_connection_attempt(); return false; } @@ -831,12 +874,13 @@ namespace ccf size_t size_ = size; serialized::skip(data_, size_, (size_ - sizeof(GcmHdr))); - const auto& hdr = serialized::overlay(data_, size_); + GcmHdr hdr; + hdr.deserialise(data_, size_); size -= sizeof(GcmHdr); if (!verify_or_decrypt(hdr, {data, size})) { - LOG_FAIL_FMT("Failed to verify node message from {}", peer_id); + CHANNEL_RECV_FAIL("Failed to verify node message with payload"); return false; } @@ -844,196 +888,144 @@ namespace ccf } std::optional> recv_encrypted( - CBuffer aad, const uint8_t* data, size_t size) + CBuffer aad, const uint8_t*& data, size_t& size) { // Receive encrypted message, returning the decrypted payload - if (status != ESTABLISHED) + if (!status.check(ESTABLISHED)) { LOG_INFO_FMT( "Node channel with {} cannot receive encrypted message: not " - "established", - peer_id); + "established, status={}", + peer_id, + status.value()); + advance_connection_attempt(); return std::nullopt; } - const auto& hdr = serialized::overlay(data, size); + GcmHdr hdr; + hdr.deserialise(data, size); + std::vector plain(size); if (!verify_or_decrypt(hdr, aad, {data, size}, plain)) { - LOG_FAIL_FMT("Failed to decrypt node message from {}", peer_id); + CHANNEL_RECV_FAIL("Failed to decrypt node message"); return std::nullopt; } return plain; } + void close_channel() + { + RINGBUFFER_WRITE_MESSAGE(close_node_outbound, to_host, peer_id.value()); + reset(); + outgoing_msg.reset(); + } + void reset() { LOG_INFO_FMT("Resetting channel with {}", peer_id); - reset_outgoing(); - status = INACTIVE; + status.advance(INACTIVE); kex_ctx.reset(); peer_cert = {}; peer_cv.reset(); recv_key.reset(); - next_recv_key.reset(); send_key.reset(); - outgoing_msg.reset(); auto e = crypto::create_entropy(); hkdf_salt = e->random(salt_len); } - }; - - class ChannelManager - { - private: - std::unordered_map> channels; - ringbuffer::AbstractWriterFactory& writer_factory; - const crypto::Pem& network_cert; - crypto::KeyPairPtr node_kp; - NodeId self; - std::optional endorsed_node_cert = std::nullopt; - std::mutex lock; - - public: - ChannelManager( - ringbuffer::AbstractWriterFactory& writer_factory_, - const crypto::Pem& network_cert_, - crypto::KeyPairPtr node_kp_, - const NodeId& self_, - std::optional endorsed_node_cert_ = std::nullopt) : - writer_factory(writer_factory_), - network_cert(network_cert_), - node_kp(node_kp_), - self(self_), - endorsed_node_cert(endorsed_node_cert_) - {} - - void set_endorsed_node_cert(const crypto::Pem& endorsed_node_cert_) - { - std::lock_guard guard(lock); - endorsed_node_cert = endorsed_node_cert_; - } - void create_channel( - const NodeId& peer_id, - const std::string& hostname, - const std::string& service, - size_t message_limit = Channel::default_message_limit) + bool recv_key_exchange_message(const uint8_t* data, size_t size) { - std::lock_guard guard(lock); - CCF_ASSERT_FMT( - endorsed_node_cert.has_value(), - "Endorsed node certificate has not yet been set"); - - auto search = channels.find(peer_id); - if (search == channels.end()) - { - LOG_DEBUG_FMT( - "Creating new outbound channel to {} ({}:{})", - peer_id, - hostname, - service); - auto channel = std::make_shared( - writer_factory, - network_cert, - node_kp, - endorsed_node_cert.value(), - self, - peer_id, - hostname, - service, - message_limit); - channels.emplace_hint(search, peer_id, std::move(channel)); - } - else if (search->second && !search->second->is_outgoing()) + try { - // Channel with peer already exists but is incoming. Create host - // outgoing connection. - LOG_DEBUG_FMT("Setting existing channel to {} as outgoing", peer_id); - search->second->set_outgoing(hostname, service); - return; + auto chmsg = serialized::read(data, size); + switch (chmsg) + { + case key_exchange_init: + { + // In the case of concurrent key_exchange_init's from both nodes, + // the one with the lower ID wins. + return recv_key_exchange_init(data, size, self < peer_id); + } + + case key_exchange_response: + { + return recv_key_exchange_response(data, size); + } + + case key_exchange_final: + { + return recv_key_exchange_final(data, size); + } + + default: + { + throw std::runtime_error(fmt::format( + "Received message with initial bytes {} from {} - not recognised " + "as a key exchange message", + chmsg, + peer_id)); + } + } } - else if (!search->second) + catch (const std::exception& e) { - LOG_INFO_FMT( - "Re-creating new outbound channel to {} ({}:{})", - peer_id, - hostname, - service); - search->second = std::make_shared( - writer_factory, - network_cert, - node_kp, - endorsed_node_cert.value(), - self, - peer_id, - hostname, - service, - message_limit); + LOG_FAIL_EXC(e.what()); + return false; } } + }; +} - void destroy_channel(const NodeId& peer_id) - { - std::lock_guard guard(lock); - auto search = channels.find(peer_id); - if (search == channels.end()) - { - LOG_FAIL_FMT( - "Cannot destroy node channel with {}: channel does not exist", - peer_id); - return; - } - - search->second = nullptr; - } +#pragma clang diagnostic pop - void destroy_all_channels() +namespace fmt +{ + template <> + struct formatter + { + template + constexpr auto parse(ParseContext& ctx) { - std::lock_guard guard(lock); - channels.clear(); + return ctx.begin(); } - void close_all_outgoing() + template + auto format(const ccf::ChannelStatus& cs, FormatContext& ctx) { - std::lock_guard guard(lock); - for (auto& c : channels) + char const* s = "Unknown"; + switch (cs) { - if (c.second && c.second->is_outgoing()) + case (ccf::INACTIVE): { - c.second->reset_outgoing(); + s = "INACTIVE"; + break; + } + case (ccf::INITIATED): + { + s = "INITIATED"; + break; + } + case (ccf::WAITING_FOR_FINAL): + { + s = "WAITING_FOR_FINAL"; + break; + } + case (ccf::ESTABLISHED): + { + s = "ESTABLISHED"; + break; } - } - } - - bool have_channel(const NodeId& nid) const - { - return channels.find(nid) != channels.end(); - } - - std::shared_ptr get(const NodeId& peer_id) - { - std::lock_guard guard(lock); - auto search = channels.find(peer_id); - if (search != channels.end()) - { - return search->second; } - // Creating temporary channel that is not outgoing (at least for now) - channels.try_emplace( - peer_id, - std::make_shared( - writer_factory, - network_cert, - node_kp, - endorsed_node_cert.value(), - self, - peer_id)); - return channels.at(peer_id); + return format_to(ctx.out(), s); } }; } + +#undef CHANNEL_RECV_TRACE +#undef CHANNEL_SEND_TRACE +#undef CHANNEL_RECV_FAIL diff --git a/src/node/node_state.h b/src/node/node_state.h index f48a8bc7c787..ab06e8f14810 100644 --- a/src/node/node_state.h +++ b/src/node/node_state.h @@ -22,6 +22,7 @@ #include "network_state.h" #include "node/http_node_client.h" #include "node/jwt_key_auto_refresh.h" +#include "node/node_to_node_channel_manager.h" #include "node/progress_tracker.h" #include "node/reconfig_id.h" #include "node/rpc/serdes.h" @@ -306,7 +307,8 @@ namespace ccf sig_tx_interval = sig_tx_interval_; sig_ms_interval = sig_ms_interval_; - n2n_channels = std::make_shared(writer_factory); + n2n_channels = + std::make_shared(writer_factory); cmd_forwarder = std::make_shared>( rpc_sessions_, n2n_channels, rpc_map, consensus_config.consensus_type); @@ -1363,6 +1365,9 @@ namespace ccf !sm.check(State::partOfPublicNetwork) && !sm.check(State::readingPrivateLedger)) { + LOG_DEBUG_FMT( + "Ignoring node msg received too early - current state is {}", + sm.value()); return; } @@ -1370,9 +1375,11 @@ namespace ccf { case channel_msg: { - n2n_channels->recv_message(from, payload_data, payload_size); + n2n_channels->recv_channel_message( + from, payload_data, payload_size); break; } + case consensus_msg: { consensus->recv_message(from, payload_data, payload_size); diff --git a/src/node/node_to_node.h b/src/node/node_to_node.h index 9587f48bcf6a..d4cf2700fddc 100644 --- a/src/node/node_to_node.h +++ b/src/node/node_to_node.h @@ -2,7 +2,6 @@ // Licensed under the Apache 2.0 License. #pragma once -#include "channels.h" #include "ds/logger.h" #include "ds/serialized.h" #include "enclave/rpc_handler.h" @@ -26,17 +25,12 @@ namespace ccf DroppedMessageException(const NodeId& from) : from(from) {} }; - virtual void create_channel( + virtual void associate_node_address( const NodeId& peer_id, const std::string& peer_hostname, - const std::string& peer_service, - size_t message_limit = Channel::default_message_limit) = 0; + const std::string& peer_service) = 0; - virtual void destroy_channel(const NodeId& peer_id) = 0; - - virtual void close_all_outgoing() = 0; - - virtual void destroy_all_channels() = 0; + virtual void close_channel(const NodeId& peer_id) = 0; virtual bool have_channel(const NodeId& nid) const = 0; @@ -93,9 +87,12 @@ namespace ccf const NodeId& from, const uint8_t*& data, size_t& size) = 0; virtual bool recv_authenticated( - const NodeId& from, CBuffer cb, const uint8_t*& data, size_t& size) = 0; + const NodeId& from, + CBuffer header, + const uint8_t*& data, + size_t& size) = 0; - virtual void recv_message( + virtual bool recv_channel_message( const NodeId& from, const uint8_t* data, size_t size) = 0; virtual void initialize( @@ -110,7 +107,7 @@ namespace ccf virtual bool send_encrypted( const NodeId& to, NodeMsgType type, - CBuffer cb, + CBuffer header, const std::vector& data) = 0; template @@ -125,7 +122,7 @@ namespace ccf template std::pair> recv_encrypted( - const NodeId& from, const uint8_t* data, size_t size) + const NodeId& from, const uint8_t*& data, size_t& size) { auto t = serialized::read(data, size); @@ -134,221 +131,6 @@ namespace ccf } virtual std::vector recv_encrypted( - const NodeId& from, CBuffer cb, const uint8_t* data, size_t size) = 0; - }; - - class NodeToNodeImpl : public NodeToNode - { - private: - std::optional self = std::nullopt; - std::unique_ptr channels; - ringbuffer::AbstractWriterFactory& writer_factory; - - public: - NodeToNodeImpl(ringbuffer::AbstractWriterFactory& writer_factory_) : - writer_factory(writer_factory_) - {} - - void initialize( - const NodeId& self_id, - const crypto::Pem& network_cert, - crypto::KeyPairPtr node_kp, - const std::optional& node_cert = std::nullopt) override - { - CCF_ASSERT_FMT( - !self.has_value(), - "Calling initialize more than once, previous id:{}, new id:{}", - self.value(), - self_id); - - if ( - node_cert.has_value() && - make_verifier(node_cert.value())->is_self_signed()) - { - LOG_FAIL_FMT( - "Refusing to initialize node-to-node channels with self-signed node " - "certificate"); - return; - } - - self = self_id; - channels = std::make_unique( - writer_factory, network_cert, node_kp, self.value(), node_cert); - } - - void set_endorsed_node_cert(const crypto::Pem& endorsed_node_cert) override - { - channels->set_endorsed_node_cert(endorsed_node_cert); - } - - void create_channel( - const NodeId& peer_id, - const std::string& hostname, - const std::string& service, - size_t message_limit = Channel::default_message_limit) override - { - if (peer_id == self.value()) - { - return; - } - - channels->create_channel(peer_id, hostname, service, message_limit); - } - - void destroy_channel(const NodeId& peer_id) override - { - if (peer_id == self.value()) - { - return; - } - - channels->destroy_channel(peer_id); - } - - void close_all_outgoing() override - { - channels->close_all_outgoing(); - } - - void destroy_all_channels() override - { - channels->destroy_all_channels(); - } - - bool have_channel(const NodeId& nid) const override - { - return channels->have_channel(nid); - } - - bool send_authenticated( - const NodeId& to, - NodeMsgType type, - const uint8_t* data, - size_t size) override - { - auto n2n_channel = channels->get(to); - return n2n_channel->send(type, {data, size}); - } - - bool recv_authenticated( - const NodeId& from, - CBuffer cb, - const uint8_t*& data, - size_t& size) override - { - auto n2n_channel = channels->get(from); - // Receiving after a channel has been destroyed is ok. - return n2n_channel ? n2n_channel->recv_authenticated(cb, data, size) : - true; - } - - bool send_encrypted( - const NodeId& to, - NodeMsgType type, - CBuffer cb, - const std::vector& data) override - { - auto n2n_channel = channels->get(to); - return n2n_channel ? n2n_channel->send(type, cb, data) : true; - } - - bool recv_authenticated_with_load( - const NodeId& from, const uint8_t*& data, size_t& size) override - { - auto n2n_channel = channels->get(from); - return n2n_channel ? - n2n_channel->recv_authenticated_with_load(data, size) : - true; - } - - std::vector recv_encrypted( - const NodeId& from, CBuffer cb, const uint8_t* data, size_t size) override - { - auto n2n_channel = channels->get(from); - - if (!n2n_channel) - return {}; - - auto plain = n2n_channel->recv_encrypted(cb, data, size); - if (!plain.has_value()) - { - throw DroppedMessageException(from); - } - - return plain.value(); - } - - void process_key_exchange_init( - const NodeId& from, const uint8_t* data, size_t size) - { - LOG_DEBUG_FMT("key_exchange_init from {}", from); - - // In the case of concurrent key_exchange_init's from both nodes, the one - // with the lower ID wins. - - auto n2n_channel = channels->get(from); - if (n2n_channel) - n2n_channel->consume_initiator_key_share(data, size, self < from); - } - - void process_key_exchange_response( - const NodeId& from, const uint8_t* data, size_t size) - { - LOG_DEBUG_FMT("key_exchange_response from {}", from); - auto n2n_channel = channels->get(from); - if (n2n_channel) - n2n_channel->consume_responder_key_share(data, size); - } - - void process_key_exchange_final( - const NodeId& from, const uint8_t* data, size_t size) - { - LOG_DEBUG_FMT("key_exchange_final from {}", from); - auto n2n_channel = channels->get(from); - if ( - n2n_channel && !n2n_channel->check_peer_key_share_signature(data, size)) - { - n2n_channel->reset(); - } - } - - void recv_message( - const NodeId& from, const uint8_t* data, size_t size) override - { - try - { - auto chmsg = serialized::read(data, size); - switch (chmsg) - { - case key_exchange_init: - { - process_key_exchange_init(from, data, size); - break; - } - - case key_exchange_response: - { - process_key_exchange_response(from, data, size); - break; - } - - case key_exchange_final: - { - process_key_exchange_final(from, data, size); - break; - } - - default: - { - } - break; - } - } - catch (const std::exception& e) - { - LOG_FAIL_EXC(e.what()); - return; - } - } + const NodeId& from, CBuffer header, const uint8_t* data, size_t size) = 0; }; } diff --git a/src/node/node_to_node_channel_manager.h b/src/node/node_to_node_channel_manager.h new file mode 100644 index 000000000000..1f77d42d15d9 --- /dev/null +++ b/src/node/node_to_node_channel_manager.h @@ -0,0 +1,228 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the Apache 2.0 License. +#pragma once + +#include "channels.h" +#include "node/node_to_node.h" + +namespace ccf +{ + class NodeToNodeChannelManager : public NodeToNode + { + private: + ringbuffer::AbstractWriterFactory& writer_factory; + ringbuffer::WriterPtr to_host; + + std::unordered_map> channels; + std::mutex lock; //< Protects access to channels map + + struct ThisNode + { + NodeId node_id; + const crypto::Pem& network_cert; + crypto::KeyPairPtr node_kp; + std::optional endorsed_node_cert = std::nullopt; + }; + std::unique_ptr this_node; //< Not available at construction, only + // after calling initialize() + + size_t message_limit = Channel::default_message_limit; + + std::shared_ptr get_channel(const NodeId& peer_id) + { + CCF_ASSERT_FMT( + this_node == nullptr || this_node->node_id != peer_id, + "Requested channel with self {}", + peer_id); + + std::lock_guard guard(lock); + CCF_ASSERT_FMT( + this_node != nullptr && this_node->endorsed_node_cert.has_value(), + "Endorsed node certificate has not yet been set"); + + auto search = channels.find(peer_id); + if (search != channels.end()) + { + return search->second; + } + + // Create channel + channels.try_emplace( + peer_id, + std::make_shared( + writer_factory, + this_node->network_cert, + this_node->node_kp, + this_node->endorsed_node_cert.value(), + this_node->node_id, + peer_id, + message_limit)); + return channels.at(peer_id); + } + + public: + NodeToNodeChannelManager( + ringbuffer::AbstractWriterFactory& writer_factory_) : + writer_factory(writer_factory_), + to_host(writer_factory_.create_writer_to_outside()) + {} + + void initialize( + const NodeId& self_id, + const crypto::Pem& network_cert, + crypto::KeyPairPtr node_kp, + const std::optional& node_cert) override + { + CCF_ASSERT_FMT( + this_node == nullptr, + "Calling initialize more than once, previous id:{}, new id:{}", + this_node->node_id, + self_id); + + if ( + node_cert.has_value() && + make_verifier(node_cert.value())->is_self_signed()) + { + LOG_INFO_FMT( + "Refusing to initialize node-to-node channels with " + "self-signed node certificate."); + return; + } + + this_node = std::unique_ptr( + new ThisNode{self_id, network_cert, node_kp, node_cert}); + } + + void set_endorsed_node_cert(const crypto::Pem& endorsed_node_cert) override + { + std::lock_guard guard(lock); + this_node->endorsed_node_cert = endorsed_node_cert; + } + + void set_message_limit(size_t message_limit_) + { + message_limit = message_limit_; + } + + virtual void associate_node_address( + const NodeId& peer_id, + const std::string& peer_hostname, + const std::string& peer_service) override + { + RINGBUFFER_WRITE_MESSAGE( + ccf::associate_node_address, + to_host, + peer_id.value(), + peer_hostname, + peer_service); + } + + void close_channel(const NodeId& peer_id) override + { + get_channel(peer_id)->close_channel(); + } + + bool have_channel(const ccf::NodeId& nid) const override + { + return channels.find(nid) != channels.end(); + } + + ChannelStatus get_status(const NodeId& peer_id) + { + return get_channel(peer_id)->get_status(); + } + + bool send_authenticated( + const NodeId& to, + NodeMsgType type, + const uint8_t* data, + size_t size) override + { + CCF_ASSERT_FMT( + this_node != nullptr, + "Calling send_authenticated before channel manager is initialized"); + + return get_channel(to)->send(type, {data, size}); + } + + bool send_encrypted( + const NodeId& to, + NodeMsgType type, + CBuffer header, + const std::vector& data) override + { + CCF_ASSERT_FMT( + this_node != nullptr, + "Calling send_encrypted (to {}) before channel manager is initialized", + to); + + return get_channel(to)->send(type, header, data); + } + + bool recv_authenticated( + const NodeId& from, + CBuffer header, + const uint8_t*& data, + size_t& size) override + { + CCF_ASSERT_FMT( + this_node != nullptr, + "Calling recv_authenticated (from {}) before channel manager is " + "initialized", + from); + + return get_channel(from)->recv_authenticated(header, data, size); + } + + bool recv_authenticated_with_load( + const NodeId& from, const uint8_t*& data, size_t& size) override + { + CCF_ASSERT_FMT( + this_node != nullptr, + "Calling recv_authenticated_with_load (from {}) before channel manager " + "is initialized", + from); + + return get_channel(from)->recv_authenticated_with_load(data, size); + } + + std::vector recv_encrypted( + const NodeId& from, + CBuffer header, + const uint8_t* data, + size_t size) override + { + CCF_ASSERT_FMT( + this_node != nullptr, + "Calling recv_encrypted (from {}) before channel manager is " + "initialized", + from); + + auto plain = get_channel(from)->recv_encrypted(header, data, size); + if (!plain.has_value()) + { + throw DroppedMessageException(from); + } + + return plain.value(); + } + + bool recv_channel_message( + const NodeId& from, const uint8_t* data, size_t size) override + { + CCF_ASSERT_FMT( + this_node != nullptr, + "Calling recv_message (from {}) before channel manager is " + "initialized", + from); + + return get_channel(from)->recv_key_exchange_message(data, size); + } + + // NB: Only used by tests! + bool recv_channel_message(const NodeId& from, std::vector&& body) + { + return recv_channel_message(from, body.data(), body.size()); + } + }; +} diff --git a/src/node/node_types.h b/src/node/node_types.h index c104363586fc..878e1a563c2c 100644 --- a/src/node/node_types.h +++ b/src/node/node_types.h @@ -68,13 +68,11 @@ namespace ccf /// Node-to-node related ringbuffer messages enum : ringbuffer::Message { - ///@{ /// Change the network nodes. Enclave -> Host - DEFINE_RINGBUFFER_MSG_TYPE(add_node), - DEFINE_RINGBUFFER_MSG_TYPE(remove_node), - ///@} + DEFINE_RINGBUFFER_MSG_TYPE(associate_node_address), /// Receive data from another node. Host -> Enclave + /// Args are (msg_type, from_id, payload) DEFINE_RINGBUFFER_MSG_TYPE(node_inbound), /// Send data to another node. Enclave -> Host @@ -82,12 +80,14 @@ namespace ccf /// The host may inspect the first 3, and should write the last 3 (to /// produce an equivalent node_inbound on the receiving node) DEFINE_RINGBUFFER_MSG_TYPE(node_outbound), + + /// Close connection to another node. Enclave -> Host + DEFINE_RINGBUFFER_MSG_TYPE(close_node_outbound) }; } DECLARE_RINGBUFFER_MESSAGE_PAYLOAD( - ccf::add_node, ccf::NodeId::Value, std::string, std::string); -DECLARE_RINGBUFFER_MESSAGE_PAYLOAD(ccf::remove_node, ccf::NodeId::Value); + ccf::associate_node_address, ccf::NodeId::Value, std::string, std::string); DECLARE_RINGBUFFER_MESSAGE_PAYLOAD( ccf::node_inbound, ccf::NodeMsgType, @@ -98,4 +98,6 @@ DECLARE_RINGBUFFER_MESSAGE_PAYLOAD( ccf::NodeId::Value, ccf::NodeMsgType, ccf::NodeId::Value, - serializer::ByteRange); \ No newline at end of file + serializer::ByteRange); +DECLARE_RINGBUFFER_MESSAGE_PAYLOAD( + ccf::close_node_outbound, ccf::NodeId::Value); diff --git a/src/node/rpc/forwarder.h b/src/node/rpc/forwarder.h index 7a8c1d67944a..009f520c5bcf 100644 --- a/src/node/rpc/forwarder.h +++ b/src/node/rpc/forwarder.h @@ -128,6 +128,9 @@ namespace ccf std::pair> r; try { + LOG_TRACE_FMT("Receiving forwarded command of {} bytes", size); + LOG_TRACE_FMT(" => {:02x}", fmt::join(data, data + size, "")); + r = n2n_channels->template recv_encrypted( from, data, size); } @@ -195,6 +198,9 @@ namespace ccf std::pair> r; try { + LOG_TRACE_FMT("Receiving response of {} bytes", size); + LOG_TRACE_FMT(" => {:02x}", fmt::join(data, data + size, "")); + r = n2n_channels->template recv_encrypted( from, data, size); } @@ -233,11 +239,16 @@ namespace ccf return m; } - void recv_message(const NodeId& from, const uint8_t* data, size_t size) + void recv_message(const ccf::NodeId& from, const uint8_t* data, size_t size) { try { auto forwarded_msg = serialized::peek(data, size); + LOG_TRACE_FMT( + "recv_message({}, {} bytes) (type={})", + from, + size, + (size_t)forwarded_msg); switch (forwarded_msg) { @@ -285,17 +296,12 @@ namespace ccf return; } - if (!send_forwarded_response( - ctx->session->client_session_id, - from, - fwd_handler->process_forwarded(ctx))) - { - LOG_FAIL_FMT("Could not send forwarded response to {}", from); - } - else - { - LOG_DEBUG_FMT("Sending forwarded response to {}", from); - } + // Ignore return value - false only means it is pending + send_forwarded_response( + ctx->session->client_session_id, + from, + fwd_handler->process_forwarded(ctx)); + LOG_DEBUG_FMT("Sending forwarded response to {}", from); } break; } diff --git a/src/node/rpc/frontend.h b/src/node/rpc/frontend.h index db0b3f73786c..2e1b80abcf52 100644 --- a/src/node/rpc/frontend.h +++ b/src/node/rpc/frontend.h @@ -91,8 +91,9 @@ namespace ccf { auto primary_id = consensus->primary(); - if ( - primary_id.has_value() && + if (primary_id.has_value()) + { + // Ignore return value - false only means it is pending cmd_forwarder->forward_command( ctx, primary_id.value(), @@ -100,8 +101,8 @@ namespace ccf endpoints::ExecuteOutsideConsensus::Never ? consensus->active_nodes() : std::set(), - ctx->session->caller_cert)) - { + ctx->session->caller_cert); + // Indicate that the RPC has been forwarded to primary LOG_TRACE_FMT("RPC forwarded to primary {}", primary_id.value()); return std::nullopt; diff --git a/src/node/test/channels.cpp b/src/node/test/channels.cpp index f4eddc90cbc1..a79b81d131cc 100644 --- a/src/node/test/channels.cpp +++ b/src/node/test/channels.cpp @@ -5,20 +5,32 @@ #include "crypto/verifier.h" #include "ds/hex.h" #include "node/entities.h" -#include "node/node_to_node.h" +#include "node/node_to_node_channel_manager.h" #include "node/node_types.h" #include #include #include +#include #define DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN #include +namespace enclave +{ + std::atomic* host_time = nullptr; + std::chrono::microseconds last_value(0); +} + +namespace ccf +{ + std::chrono::microseconds Channel::min_gap_between_initiation_attempts(0); +} + threading::ThreadMessaging threading::ThreadMessaging::thread_messaging; std::atomic threading::ThreadMessaging::thread_count = 0; -constexpr auto buffer_size = 1024 * 16; +constexpr auto buffer_size = 1024 * 8; auto in_buffer_1 = std::make_unique(buffer_size); auto out_buffer_1 = std::make_unique(buffer_size); @@ -38,8 +50,8 @@ using namespace ccf; static constexpr auto msg_size = 64; using MsgType = std::array; -static NodeId self = std::string("self"); -static NodeId peer = std::string("peer"); +static NodeId nid1 = std::string("nid1"); +static NodeId nid2 = std::string("nid2"); static constexpr auto default_curve = crypto::CurveID::SECP384R1; @@ -47,6 +59,7 @@ template struct NodeOutboundMsg { NodeId from; + NodeId to; NodeMsgType type; T authenticated_hdr; std::vector payload; @@ -76,40 +89,44 @@ auto read_outbound_msgs(ringbuffer::Circuit& circuit) { std::vector> msgs; - circuit.read_from_inside().read( - -1, [&](ringbuffer::Message m, const uint8_t* data, size_t size) { - switch (m) - { - case node_outbound: - { - serialized::read( - data, size); // Ignore destination node id - auto msg_type = serialized::read(data, size); - NodeId from = serialized::read(data, size); - T aad; - if (size > sizeof(T)) - aad = serialized::read(data, size); - auto payload = serialized::read(data, size, size); - msgs.push_back(NodeOutboundMsg{from, msg_type, aad, payload}); - break; - } - case add_node: - { - LOG_DEBUG_FMT("Add node msg!"); - break; - } - case remove_node: - { - LOG_DEBUG_FMT("Remove node msg!"); - break; - } - default: + // A call to ringbuffer::Reader::read() may return 0 when there are still + // messages to read, when it reaches the end of the buffer. The next call to + // read() will correctly start at the beginning of the buffer and read these + // messages. So to make sure we always get the messages we expect in this + // test, read twice. + for (size_t i = 0; i < 2; ++i) + { + circuit.read_from_inside().read( + -1, [&](ringbuffer::Message m, const uint8_t* data, size_t size) { + switch (m) { - LOG_DEBUG_FMT("Outbound message is not expected: {}", m); - REQUIRE(false); + case node_outbound: + { + NodeId to = serialized::read(data, size); + auto msg_type = serialized::read(data, size); + NodeId from = serialized::read(data, size); + T aad; + if (size > sizeof(T)) + aad = serialized::read(data, size); + auto payload = serialized::read(data, size, size); + msgs.push_back( + NodeOutboundMsg{from, to, msg_type, aad, payload}); + break; + } + case associate_node_address: + case close_node_outbound: + { + // Ignored + break; + } + default: + { + LOG_INFO_FMT("Outbound message is not expected: {}", m); + REQUIRE(false); + } } - } - }); + }); + } return msgs; } @@ -117,42 +134,35 @@ auto read_outbound_msgs(ringbuffer::Circuit& circuit) auto read_node_msgs(ringbuffer::Circuit& circuit) { std::vector> add_node_msgs; - std::vector remove_node_msgs; circuit.read_from_inside().read( -1, [&](ringbuffer::Message m, const uint8_t* data, size_t size) { switch (m) { - case add_node: + case ccf::associate_node_address: { auto [id, hostname, service] = - ringbuffer::read_message(data, size); + ringbuffer::read_message(data, size); add_node_msgs.push_back(std::make_tuple(id, hostname, service)); break; } - case remove_node: - { - auto [id] = ringbuffer::read_message(data, size); - remove_node_msgs.push_back(id); - break; - } default: { - LOG_DEBUG_FMT("Outbound message is not expected: {}", m); + LOG_INFO_FMT("Outbound message is not expected: {}", m); REQUIRE(false); } } }); - return std::make_pair(add_node_msgs, remove_node_msgs); + return add_node_msgs; } NodeOutboundMsg get_first( ringbuffer::Circuit& circuit, NodeMsgType msg_type) { auto outbound_msgs = read_outbound_msgs(circuit); - REQUIRE(outbound_msgs.size() == 1); + REQUIRE(outbound_msgs.size() >= 1); auto msg = outbound_msgs[0]; const auto* data_ = msg.payload.data(); auto size_ = msg.payload.size(); @@ -180,10 +190,10 @@ TEST_CASE("Client/Server key exchange") REQUIRE(!make_verifier(channel2_cert)->is_self_signed()); - auto channel1 = - Channel(wf1, network_cert, channel1_kp, channel1_cert, self, peer); - auto channel2 = - Channel(wf2, network_cert, channel2_kp, channel2_cert, peer, self); + auto channels1 = NodeToNodeChannelManager(wf1); + channels1.initialize(nid1, network_cert, channel1_kp, channel1_cert); + auto channels2 = NodeToNodeChannelManager(wf2); + channels2.initialize(nid2, network_cert, channel2_kp, channel2_cert); MsgType msg; msg.fill(0x42); @@ -191,17 +201,17 @@ TEST_CASE("Client/Server key exchange") INFO("Trying to tag/verify before channel establishment"); { // Try sending on channel1 twice - REQUIRE_FALSE( - channel1.send(NodeMsgType::consensus_msg, {msg.begin(), msg.size()})); - REQUIRE_FALSE( - channel1.send(NodeMsgType::consensus_msg, {msg.begin(), msg.size()})); + REQUIRE_FALSE(channels1.send_authenticated( + nid2, NodeMsgType::consensus_msg, msg.begin(), msg.size())); + REQUIRE_FALSE(channels1.send_authenticated( + nid2, NodeMsgType::consensus_msg, msg.begin(), msg.size())); } std::vector channel1_signed_key_share; INFO("Extract key share, signature, certificate from messages"); { - // Every send has been replaced with a new channel establishment message + // Attempting to send has produced a new channel establishment message auto msgs = read_outbound_msgs(eio1); REQUIRE(msgs.size() == 2); REQUIRE(msgs[0].type == channel_msg); @@ -210,17 +220,19 @@ TEST_CASE("Client/Server key exchange") #ifndef DETERMINISTIC_ECDSA // Signing twice should have produced different signatures - REQUIRE(msgs[0].unauthenticated_data() != msgs[1].unauthenticated_data()); + REQUIRE(msgs[0].data() != msgs[1].data()); #endif - channel1_signed_key_share = msgs[0].unauthenticated_data(); + // Use the latter attempt - it is the state channel1 is working with + channel1_signed_key_share = msgs[1].data(); } INFO("Load peer key share and check signature"); { - REQUIRE(channel2.consume_initiator_key_share(channel1_signed_key_share)); - REQUIRE(channel1.get_status() == INITIATED); - REQUIRE(channel2.get_status() == WAITING_FOR_FINAL); + REQUIRE(channels2.recv_channel_message( + nid1, std::move(channel1_signed_key_share))); + REQUIRE(channels1.get_status(nid2) == INITIATED); + REQUIRE(channels2.get_status(nid1) == WAITING_FOR_FINAL); } std::vector channel2_signed_key_share; @@ -231,15 +243,16 @@ TEST_CASE("Client/Server key exchange") auto msgs = read_outbound_msgs(eio2); REQUIRE(msgs.size() == 1); REQUIRE(msgs[0].type == channel_msg); - channel2_signed_key_share = msgs[0].unauthenticated_data(); + channel2_signed_key_share = msgs[0].data(); REQUIRE(read_outbound_msgs(eio1).size() == 0); } INFO("Load responder key share and check signature"); { - REQUIRE(channel1.consume_responder_key_share(channel2_signed_key_share)); - REQUIRE(channel1.get_status() == ESTABLISHED); - REQUIRE(channel2.get_status() == WAITING_FOR_FINAL); + REQUIRE(channels1.recv_channel_message( + nid2, std::move(channel2_signed_key_share))); + REQUIRE(channels1.get_status(nid2) == ESTABLISHED); + REQUIRE(channels2.get_status(nid1) == WAITING_FOR_FINAL); } std::vector initiator_signature; @@ -251,7 +264,7 @@ TEST_CASE("Client/Server key exchange") REQUIRE(msgs.size() == 2); REQUIRE(msgs[0].type == channel_msg); REQUIRE(msgs[1].type == consensus_msg); - initiator_signature = msgs[0].unauthenticated_data(); + initiator_signature = msgs[0].data(); auto md = msgs[1].data(); REQUIRE(md.size() == msg.size() + sizeof(GcmHdr)); @@ -262,9 +275,10 @@ TEST_CASE("Client/Server key exchange") INFO("Cross-check responder signature and establish channels"); { - REQUIRE(channel2.check_peer_key_share_signature(initiator_signature)); - REQUIRE(channel1.get_status() == ESTABLISHED); - REQUIRE(channel2.get_status() == ESTABLISHED); + REQUIRE( + channels2.recv_channel_message(nid1, std::move(initiator_signature))); + REQUIRE(channels1.get_status(nid2) == ESTABLISHED); + REQUIRE(channels2.get_status(nid1) == ESTABLISHED); } INFO("Receive queued message"); @@ -274,13 +288,14 @@ TEST_CASE("Client/Server key exchange") auto payload = queued_msg.payload; const auto* data = payload.data(); auto size = payload.size(); - channel2.recv_authenticated({hdr.begin(), hdr.size()}, data, size); + REQUIRE(channels2.recv_authenticated( + nid1, {hdr.begin(), hdr.size()}, data, size)); } INFO("Protect integrity of message (peer1 -> peer2)"); { - REQUIRE( - channel1.send(NodeMsgType::consensus_msg, {msg.begin(), msg.size()})); + REQUIRE(channels1.send_authenticated( + nid2, NodeMsgType::consensus_msg, msg.begin(), msg.size())); auto outbound_msgs = read_outbound_msgs(eio1); REQUIRE(outbound_msgs.size() == 1); auto msg_ = outbound_msgs[0]; @@ -288,7 +303,8 @@ TEST_CASE("Client/Server key exchange") auto size_ = msg_.payload.size(); REQUIRE(msg_.type == NodeMsgType::consensus_msg); - REQUIRE(channel2.recv_authenticated( + REQUIRE(channels2.recv_authenticated( + nid1, {msg_.authenticated_hdr.begin(), msg_.authenticated_hdr.size()}, data_, size_)); @@ -296,8 +312,8 @@ TEST_CASE("Client/Server key exchange") INFO("Protect integrity of message (peer2 -> peer1)"); { - REQUIRE( - channel2.send(NodeMsgType::consensus_msg, {msg.begin(), msg.size()})); + REQUIRE(channels2.send_authenticated( + nid1, NodeMsgType::consensus_msg, msg.begin(), msg.size())); auto outbound_msgs = read_outbound_msgs(eio2); REQUIRE(outbound_msgs.size() == 1); auto msg_ = outbound_msgs[0]; @@ -305,7 +321,8 @@ TEST_CASE("Client/Server key exchange") auto size_ = msg_.payload.size(); REQUIRE(msg_.type == NodeMsgType::consensus_msg); - REQUIRE(channel1.recv_authenticated( + REQUIRE(channels1.recv_authenticated( + nid2, {msg_.authenticated_hdr.begin(), msg_.authenticated_hdr.size()}, data_, size_)); @@ -313,8 +330,8 @@ TEST_CASE("Client/Server key exchange") INFO("Tamper with message"); { - REQUIRE( - channel1.send(NodeMsgType::consensus_msg, {msg.begin(), msg.size()})); + REQUIRE(channels1.send_authenticated( + nid2, NodeMsgType::consensus_msg, msg.begin(), msg.size())); auto outbound_msgs = read_outbound_msgs(eio1); REQUIRE(outbound_msgs.size() == 1); auto msg_ = outbound_msgs[0]; @@ -323,7 +340,8 @@ TEST_CASE("Client/Server key exchange") auto size_ = msg_.payload.size(); REQUIRE(msg_.type == NodeMsgType::consensus_msg); - REQUIRE_FALSE(channel2.recv_authenticated( + REQUIRE_FALSE(channels2.recv_authenticated( + nid1, {msg_.authenticated_hdr.begin(), msg_.authenticated_hdr.size()}, data_, size_)); @@ -332,45 +350,33 @@ TEST_CASE("Client/Server key exchange") INFO("Encrypt message (peer1 -> peer2)"); { std::vector plain_text(128, 0x1); - REQUIRE(channel1.send( - NodeMsgType::consensus_msg, {msg.begin(), msg.size()}, plain_text)); + REQUIRE(channels1.send_encrypted( + nid2, NodeMsgType::consensus_msg, {msg.begin(), msg.size()}, plain_text)); - auto outbound_msgs = read_outbound_msgs(eio1); - REQUIRE(outbound_msgs.size() == 1); - auto msg_ = outbound_msgs[0]; - const auto* data_ = msg_.payload.data(); - auto size_ = msg_.payload.size(); - REQUIRE(msg_.type == NodeMsgType::consensus_msg); + auto msg_ = get_first(eio1, NodeMsgType::consensus_msg); + auto decrypted = channels2.recv_encrypted( + nid1, + {msg_.authenticated_hdr.data(), msg_.authenticated_hdr.size()}, + msg_.payload.data(), + msg_.payload.size()); - auto decrypted = channel2.recv_encrypted( - {msg_.authenticated_hdr.begin(), msg_.authenticated_hdr.size()}, - data_, - size_); - - REQUIRE(decrypted.has_value()); - REQUIRE(decrypted.value() == plain_text); + REQUIRE(decrypted == plain_text); } INFO("Encrypt message (peer2 -> peer1)"); { - std::vector plain_text(128, 0x1); - REQUIRE(channel2.send( - NodeMsgType::consensus_msg, {msg.begin(), msg.size()}, plain_text)); - - auto outbound_msgs = read_outbound_msgs(eio2); - REQUIRE(outbound_msgs.size() == 1); - auto msg_ = outbound_msgs[0]; - const auto* data_ = msg_.payload.data(); - auto size_ = msg_.payload.size(); - REQUIRE(msg_.type == NodeMsgType::consensus_msg); - - auto decrypted = channel1.recv_encrypted( - {msg_.authenticated_hdr.begin(), msg_.authenticated_hdr.size()}, - data_, - size_); - - REQUIRE(decrypted.has_value()); - REQUIRE(decrypted.value() == plain_text); + std::vector plain_text(128, 0x2); + REQUIRE(channels2.send_encrypted( + nid1, NodeMsgType::consensus_msg, {msg.begin(), msg.size()}, plain_text)); + + auto msg_ = get_first(eio2, NodeMsgType::consensus_msg); + auto decrypted = channels1.recv_encrypted( + nid2, + {msg_.authenticated_hdr.data(), msg_.authenticated_hdr.size()}, + msg_.payload.data(), + msg_.payload.size()); + + REQUIRE(decrypted == plain_text); } } @@ -387,58 +393,74 @@ TEST_CASE("Replay and out-of-order") auto channel2_csr = channel2_kp->create_csr("CN=Node2"); auto channel2_cert = network_kp->sign_csr(network_cert, channel2_csr, {}); - auto channel1 = - Channel(wf1, network_cert, channel1_kp, channel1_cert, self, peer); - auto channel2 = - Channel(wf2, network_cert, channel2_kp, channel2_cert, peer, self); + auto channels1 = NodeToNodeChannelManager(wf1); + channels1.initialize(nid1, network_cert, channel1_kp, channel1_cert); + auto channels2 = NodeToNodeChannelManager(wf2); + channels2.initialize(nid2, network_cert, channel2_kp, channel2_cert); MsgType msg; msg.fill(0x42); INFO("Establish channels"); { - channel1.initiate(); + channels1.send_authenticated( + nid2, NodeMsgType::consensus_msg, msg.data(), msg.size()); auto msgs = read_outbound_msgs(eio1); REQUIRE(msgs.size() == 1); REQUIRE(msgs[0].type == channel_msg); - auto channel1_signed_key_share = msgs[0].unauthenticated_data(); + auto channel1_signed_key_share = msgs[0].data(); - REQUIRE(channel2.consume_initiator_key_share(channel1_signed_key_share)); + REQUIRE(channels2.recv_channel_message( + nid1, std::move(channel1_signed_key_share))); msgs = read_outbound_msgs(eio2); REQUIRE(msgs.size() == 1); REQUIRE(msgs[0].type == channel_msg); - auto channel2_signed_key_share = msgs[0].unauthenticated_data(); - REQUIRE(channel1.consume_responder_key_share(channel2_signed_key_share)); + auto channel2_signed_key_share = msgs[0].data(); + REQUIRE(channels1.recv_channel_message( + nid2, std::move(channel2_signed_key_share))); msgs = read_outbound_msgs(eio1); - REQUIRE(msgs.size() == 1); + REQUIRE(msgs.size() == 2); REQUIRE(msgs[0].type == channel_msg); - auto initiator_signature = msgs[0].unauthenticated_data(); + auto initiator_signature = msgs[0].data(); + + REQUIRE( + channels2.recv_channel_message(nid1, std::move(initiator_signature))); + REQUIRE(channels1.get_status(nid2) == ESTABLISHED); + REQUIRE(channels2.get_status(nid1) == ESTABLISHED); + + REQUIRE(msgs[1].type == consensus_msg); - REQUIRE(channel2.check_peer_key_share_signature(initiator_signature)); - REQUIRE(channel1.get_status() == ESTABLISHED); - REQUIRE(channel2.get_status() == ESTABLISHED); + const auto* payload_data = msgs[1].payload.data(); + auto payload_size = msgs[1].payload.size(); + REQUIRE(channels2.recv_authenticated( + nid1, + {msgs[1].authenticated_hdr.data(), msgs[1].authenticated_hdr.size()}, + payload_data, + payload_size)); } NodeOutboundMsg first_msg, first_msg_copy; INFO("Replay same message"); { - REQUIRE( - channel1.send(NodeMsgType::consensus_msg, {msg.begin(), msg.size()})); + REQUIRE(channels1.send_authenticated( + nid2, NodeMsgType::consensus_msg, msg.begin(), msg.size())); auto outbound_msgs = read_outbound_msgs(eio1); REQUIRE(outbound_msgs.size() == 1); first_msg = outbound_msgs[0]; - REQUIRE(first_msg.from == self); + REQUIRE(first_msg.from == nid1); + REQUIRE(first_msg.to == nid2); auto msg_copy = first_msg; first_msg_copy = first_msg; const auto* data_ = first_msg.payload.data(); auto size_ = first_msg.payload.size(); REQUIRE(first_msg.type == NodeMsgType::consensus_msg); - REQUIRE(channel2.recv_authenticated( + REQUIRE(channels2.recv_authenticated( + nid1, {first_msg.authenticated_hdr.begin(), first_msg.authenticated_hdr.size()}, data_, size_)); @@ -446,7 +468,8 @@ TEST_CASE("Replay and out-of-order") // Replay data_ = msg_copy.payload.data(); size_ = msg_copy.payload.size(); - REQUIRE_FALSE(channel2.recv_authenticated( + REQUIRE_FALSE(channels2.recv_authenticated( + nid1, {msg_copy.authenticated_hdr.begin(), msg_copy.authenticated_hdr.size()}, data_, size_)); @@ -454,16 +477,16 @@ TEST_CASE("Replay and out-of-order") INFO("Issue more messages and replay old one"); { - REQUIRE( - channel1.send(NodeMsgType::consensus_msg, {msg.begin(), msg.size()})); + REQUIRE(channels1.send_authenticated( + nid2, NodeMsgType::consensus_msg, msg.begin(), msg.size())); REQUIRE(read_outbound_msgs(eio1).size() == 1); - REQUIRE( - channel1.send(NodeMsgType::consensus_msg, {msg.begin(), msg.size()})); + REQUIRE(channels1.send_authenticated( + nid2, NodeMsgType::consensus_msg, msg.begin(), msg.size())); REQUIRE(read_outbound_msgs(eio1).size() == 1); - REQUIRE( - channel1.send(NodeMsgType::consensus_msg, {msg.begin(), msg.size()})); + REQUIRE(channels1.send_authenticated( + nid2, NodeMsgType::consensus_msg, msg.begin(), msg.size())); auto outbound_msgs = read_outbound_msgs(eio1); REQUIRE(outbound_msgs.size() == 1); auto msg_ = outbound_msgs[0]; @@ -471,14 +494,16 @@ TEST_CASE("Replay and out-of-order") auto size_ = msg_.payload.size(); REQUIRE(msg_.type == NodeMsgType::consensus_msg); - REQUIRE(channel2.recv_authenticated( + REQUIRE(channels2.recv_authenticated( + nid1, {msg_.authenticated_hdr.begin(), msg_.authenticated_hdr.size()}, data_, size_)); const auto* first_msg_data_ = first_msg_copy.payload.data(); auto first_msg_size_ = first_msg_copy.payload.size(); - REQUIRE_FALSE(channel2.recv_authenticated( + REQUIRE_FALSE(channels2.recv_authenticated( + nid1, {first_msg_copy.authenticated_hdr.begin(), first_msg_copy.authenticated_hdr.size()}, first_msg_data_, @@ -491,26 +516,42 @@ TEST_CASE("Replay and out-of-order") read_outbound_msgs(eio2).size(); REQUIRE(n == 0); - channel1.initiate(); - REQUIRE(channel1.get_status() == ESTABLISHED); - REQUIRE(channel2.get_status() == ESTABLISHED); - - auto fst = get_first(eio1, NodeMsgType::channel_msg); - REQUIRE( - channel2.consume_initiator_key_share(fst.unauthenticated_data(), true)); - REQUIRE(channel2.get_status() == ESTABLISHED); - fst = get_first(eio2, NodeMsgType::channel_msg); - REQUIRE(channel1.consume_responder_key_share(fst.unauthenticated_data())); - auto msgs = read_outbound_msgs(eio1); - REQUIRE(msgs.size() == 1); - REQUIRE( - channel2.check_peer_key_share_signature(msgs[0].unauthenticated_data())); - REQUIRE(channel1.get_status() == ESTABLISHED); - REQUIRE(channel2.get_status() == ESTABLISHED); - - REQUIRE( - channel1.send(NodeMsgType::consensus_msg, {msg.begin(), msg.size()})); - fst = get_first(eio1, NodeMsgType::consensus_msg); + channels1.close_channel(nid2); + REQUIRE(channels1.get_status(nid2) == INACTIVE); + REQUIRE(channels2.get_status(nid1) == ESTABLISHED); + + channels1.send_authenticated( + nid2, NodeMsgType::consensus_msg, msg.data(), msg.size()); + REQUIRE(channels1.get_status(nid2) == INITIATED); + REQUIRE(channels2.get_status(nid1) == ESTABLISHED); + + REQUIRE(channels2.recv_channel_message( + nid1, get_first(eio1, NodeMsgType::channel_msg).data())); + REQUIRE(channels1.get_status(nid2) == INITIATED); + REQUIRE(channels2.get_status(nid1) == WAITING_FOR_FINAL); + + REQUIRE(channels1.recv_channel_message( + nid2, get_first(eio2, NodeMsgType::channel_msg).data())); + REQUIRE(channels1.get_status(nid2) == ESTABLISHED); + REQUIRE(channels2.get_status(nid1) == WAITING_FOR_FINAL); + + auto messages_1to2 = read_outbound_msgs(eio1); + REQUIRE(messages_1to2.size() == 2); + REQUIRE(messages_1to2[0].type == NodeMsgType::channel_msg); + REQUIRE(channels2.recv_channel_message(nid1, messages_1to2[0].data())); + REQUIRE(channels1.get_status(nid2) == ESTABLISHED); + REQUIRE(channels2.get_status(nid1) == ESTABLISHED); + + REQUIRE(messages_1to2[1].type == NodeMsgType::consensus_msg); + auto final_msg = messages_1to2[1]; + const auto* payload_data = final_msg.payload.data(); + auto payload_size = final_msg.payload.size(); + + REQUIRE(channels2.recv_authenticated( + nid1, + {final_msg.authenticated_hdr.data(), final_msg.authenticated_hdr.size()}, + payload_data, + payload_size)); } } @@ -518,38 +559,49 @@ TEST_CASE("Host connections") { auto network_kp = crypto::make_key_pair(default_curve); auto network_cert = network_kp->self_sign("CN=Network"); + auto channel_kp = crypto::make_key_pair(default_curve); - auto channel_cert = channel_kp->self_sign("CN=Node"); - auto channel_manager = - ChannelManager(wf1, network_cert, channel_kp, self, channel_cert); + auto channel_csr = channel_kp->create_csr("CN=Node"); + auto channel_cert = network_kp->sign_csr(network_cert, channel_csr, {}); - INFO("New channel creates host connection"); + auto channel_manager = NodeToNodeChannelManager(wf1); + channel_manager.initialize(nid1, network_cert, channel_kp, channel_cert); + + INFO("New node association is sent as ringbuffer message"); { - channel_manager.create_channel(peer, "hostname", "port"); - auto [add_node_msgs, remove_node_msgs] = read_node_msgs(eio1); + channel_manager.associate_node_address(nid2, "hostname", "port"); + auto add_node_msgs = read_node_msgs(eio1); REQUIRE(add_node_msgs.size() == 1); - REQUIRE(remove_node_msgs.size() == 0); - REQUIRE(std::get<0>(add_node_msgs[0]) == peer); + REQUIRE(std::get<0>(add_node_msgs[0]) == nid2); REQUIRE(std::get<1>(add_node_msgs[0]) == "hostname"); REQUIRE(std::get<2>(add_node_msgs[0]) == "port"); } - INFO("Retrieving unknown channel does not create host connection"); + INFO( + "Trying to talk to node will initiate key exchange, regardless of IP " + "association"); { NodeId unknown_peer_id = std::string("unknown_peer"); - channel_manager.get(unknown_peer_id); - auto [add_node_msgs, remove_node_msgs] = read_node_msgs(eio1); - REQUIRE(add_node_msgs.size() == 0); - REQUIRE(remove_node_msgs.size() == 0); + MsgType msg; + msg.fill(0x42); + channel_manager.send_authenticated( + unknown_peer_id, NodeMsgType::consensus_msg, msg.data(), msg.size()); + auto outbound = read_outbound_msgs(eio1); + REQUIRE(outbound.size() == 1); + REQUIRE(outbound[0].type == channel_msg); } +} - INFO("Destroying channel closes host connection"); +static std::vector> get_all_msgs( + std::set eios) +{ + std::vector> res; + for (auto& eio : eios) { - channel_manager.destroy_channel(peer); - auto [add_node_msgs, remove_node_msgs] = read_node_msgs(eio1); - REQUIRE(add_node_msgs.size() == 0); - REQUIRE(remove_node_msgs.size() == 1); + auto msgs = read_outbound_msgs(*eio); + res.insert(res.end(), msgs.begin(), msgs.end()); } + return res; } TEST_CASE("Concurrent key exchange init") @@ -565,92 +617,85 @@ TEST_CASE("Concurrent key exchange init") auto channel2_csr = channel2_kp->create_csr("CN=Node2"); auto channel2_cert = network_kp->sign_csr(network_cert, channel2_csr, {}); - auto channel1 = - Channel(wf1, network_cert, channel1_kp, channel1_cert, self, peer); - auto channel2 = - Channel(wf2, network_cert, channel2_kp, channel2_cert, peer, self); + auto channels1 = NodeToNodeChannelManager(wf1); + channels1.initialize(nid1, network_cert, channel1_kp, channel1_cert); + auto channels2 = NodeToNodeChannelManager(wf2); + channels2.initialize(nid2, network_cert, channel2_kp, channel2_cert); MsgType msg; msg.fill(0x42); - INFO("Channel 1 wins"); { - channel1.initiate(); - channel2.initiate(); + INFO("Channel 2 wins"); + channels1.send_authenticated( + nid2, NodeMsgType::consensus_msg, msg.data(), msg.size()); + channels2.send_authenticated( + nid1, NodeMsgType::consensus_msg, msg.data(), msg.size()); - REQUIRE(channel1.get_status() == INITIATED); - REQUIRE(channel2.get_status() == INITIATED); + REQUIRE(channels1.get_status(nid2) == INITIATED); + REQUIRE(channels2.get_status(nid1) == INITIATED); auto fst1 = get_first(eio1, NodeMsgType::channel_msg); auto fst2 = get_first(eio2, NodeMsgType::channel_msg); - REQUIRE( - channel1.consume_initiator_key_share(fst2.unauthenticated_data(), true)); - REQUIRE( - channel2.consume_initiator_key_share(fst1.unauthenticated_data(), false)); + REQUIRE(channels1.recv_channel_message(nid2, fst2.data())); + REQUIRE(channels2.recv_channel_message(nid1, fst1.data())); - REQUIRE(channel1.get_status() == WAITING_FOR_FINAL); - REQUIRE(channel2.get_status() == INITIATED); + REQUIRE(channels1.get_status(nid2) == WAITING_FOR_FINAL); + REQUIRE(channels2.get_status(nid1) == INITIATED); fst1 = get_first(eio1, NodeMsgType::channel_msg); - REQUIRE(channel2.consume_responder_key_share(fst1.unauthenticated_data())); + REQUIRE(channels2.recv_channel_message(nid1, fst1.data())); fst2 = get_first(eio2, NodeMsgType::channel_msg); - REQUIRE( - channel1.check_peer_key_share_signature(fst2.unauthenticated_data())); + REQUIRE(channels1.recv_channel_message(nid2, fst2.data())); - REQUIRE(channel1.get_status() == ESTABLISHED); - REQUIRE(channel2.get_status() == ESTABLISHED); + REQUIRE(channels1.get_status(nid2) == ESTABLISHED); + REQUIRE(channels2.get_status(nid1) == ESTABLISHED); } - channel1.reset(); - channel2.reset(); + channels1.close_channel(nid2); + channels2.close_channel(nid1); + + read_outbound_msgs(eio1); + read_outbound_msgs(eio2); - INFO("Channel 2 wins"); { - channel1.initiate(); - channel2.initiate(); + INFO("Channel 1 wins"); + // Node 2 is higher priority, so its init attempt will win if they happen + // concurrently. However if node 1's init is received first, node 2 will use + // it. - REQUIRE(channel1.get_status() == INITIATED); - REQUIRE(channel2.get_status() == INITIATED); + channels1.send_authenticated( + nid2, NodeMsgType::consensus_msg, msg.data(), msg.size()); - auto fst1 = get_first(eio1, NodeMsgType::channel_msg); - auto fst2 = get_first(eio2, NodeMsgType::channel_msg); + REQUIRE(channels1.get_status(nid2) == INITIATED); + REQUIRE(channels2.get_status(nid1) == INACTIVE); - REQUIRE( - channel1.consume_initiator_key_share(fst2.unauthenticated_data(), false)); - REQUIRE( - channel2.consume_initiator_key_share(fst1.unauthenticated_data(), true)); + // Node 2 receives the init _before_ any excuse to init themselves + auto fst1 = get_first(eio1, NodeMsgType::channel_msg); + REQUIRE(channels2.recv_channel_message(nid1, fst1.data())); + channels2.send_authenticated( + nid1, NodeMsgType::consensus_msg, msg.data(), msg.size()); - REQUIRE(channel1.get_status() == INITIATED); - REQUIRE(channel2.get_status() == WAITING_FOR_FINAL); + REQUIRE(channels1.get_status(nid2) == INITIATED); + REQUIRE(channels2.get_status(nid1) == WAITING_FOR_FINAL); - fst2 = get_first(eio2, NodeMsgType::channel_msg); + auto fst2 = get_first(eio2, NodeMsgType::channel_msg); - REQUIRE(channel1.consume_responder_key_share(fst2.unauthenticated_data())); + REQUIRE(channels1.recv_channel_message(nid2, fst2.data())); fst1 = get_first(eio1, NodeMsgType::channel_msg); - REQUIRE( - channel2.check_peer_key_share_signature(fst1.unauthenticated_data())); + REQUIRE(channels2.recv_channel_message(nid1, fst1.data())); - REQUIRE(channel1.get_status() == ESTABLISHED); - REQUIRE(channel2.get_status() == ESTABLISHED); + REQUIRE(channels1.get_status(nid2) == ESTABLISHED); + REQUIRE(channels2.get_status(nid1) == ESTABLISHED); } -} -static std::vector> get_all_msgs( - std::set eios) -{ - std::vector> res; - for (auto& eio : eios) - { - auto msgs = read_outbound_msgs(*eio); - res.insert(res.end(), msgs.begin(), msgs.end()); - } - return res; + get_all_msgs({&eio1, &eio2}); } struct CurveChoices @@ -705,12 +750,12 @@ TEST_CASE("Full NodeToNode test") msg.fill(0x42); INFO("Set up channels"); - NodeToNodeImpl n2n1(wf1), n2n2(wf2); + NodeToNodeChannelManager n2n1(wf1), n2n2(wf2); n2n1.initialize(ni1, network_cert, channel1_kp, channel1_cert); - n2n1.create_channel(ni2, "", "", message_limit); + n2n1.set_message_limit(message_limit); n2n2.initialize(ni2, network_cert, channel2_kp, channel2_cert); - n2n2.create_channel(ni1, "", "", message_limit); + n2n2.set_message_limit(message_limit); srand(0); // keep it deterministic @@ -743,8 +788,7 @@ TEST_CASE("Full NodeToNode test") { case NodeMsgType::channel_msg: { - const auto msg_body = msg.data(); - n2n.recv_message(msg.from, msg_body.data(), msg_body.size()); + n2n.recv_channel_message(msg.from, msg.data()); auto d = msg.data(); const uint8_t* data = d.data(); @@ -776,4 +820,383 @@ TEST_CASE("Full NodeToNode test") REQUIRE(actual_rollovers >= desired_rollovers); } } +} + +TEST_CASE("Interrupted key exchange") +{ + auto network_kp = crypto::make_key_pair(default_curve); + auto network_cert = network_kp->self_sign("CN=Network"); + + auto channel1_kp = crypto::make_key_pair(default_curve); + auto channel1_csr = channel1_kp->create_csr("CN=Node1"); + auto channel1_cert = network_kp->sign_csr(network_cert, channel1_csr, {}); + + auto channel2_kp = crypto::make_key_pair(default_curve); + auto channel2_csr = channel2_kp->create_csr("CN=Node2"); + auto channel2_cert = network_kp->sign_csr(network_cert, channel2_csr, {}); + + auto channels1 = NodeToNodeChannelManager(wf1); + channels1.initialize(nid1, network_cert, channel1_kp, channel1_cert); + auto channels2 = NodeToNodeChannelManager(wf2); + channels2.initialize(nid2, network_cert, channel2_kp, channel2_cert); + + std::vector msg; + msg.push_back(0x1); + msg.push_back(0x0); + msg.push_back(0x10); + msg.push_back(0x42); + + enum class DropStage + { + InitiationMessage, + ResponseMessage, + FinalMessage, + NoDrops, + }; + + DropStage drop_stage; + for (const auto drop_stage : { + DropStage::NoDrops, + DropStage::FinalMessage, + DropStage::ResponseMessage, + DropStage::InitiationMessage, + }) + { + INFO("Drop stage is ", (size_t)drop_stage); + + auto n = read_outbound_msgs(eio1).size() + + read_outbound_msgs(eio2).size(); + REQUIRE(n == 0); + + channels1.close_channel(nid2); + channels2.close_channel(nid1); + REQUIRE(channels1.get_status(nid2) == INACTIVE); + REQUIRE(channels2.get_status(nid1) == INACTIVE); + + channels1.send_authenticated( + nid2, NodeMsgType::consensus_msg, msg.data(), msg.size()); + + REQUIRE(channels1.get_status(nid2) == INITIATED); + REQUIRE(channels2.get_status(nid1) == INACTIVE); + + auto initiator_key_share_msg = get_first(eio1, NodeMsgType::channel_msg); + if (drop_stage > DropStage::InitiationMessage) + { + REQUIRE( + channels2.recv_channel_message(nid1, initiator_key_share_msg.data())); + + REQUIRE(channels1.get_status(nid2) == INITIATED); + REQUIRE(channels2.get_status(nid1) == WAITING_FOR_FINAL); + + auto responder_key_share_msg = get_first(eio2, NodeMsgType::channel_msg); + if (drop_stage > DropStage::ResponseMessage) + { + REQUIRE( + channels1.recv_channel_message(nid2, responder_key_share_msg.data())); + + REQUIRE(channels1.get_status(nid2) == ESTABLISHED); + REQUIRE(channels2.get_status(nid1) == WAITING_FOR_FINAL); + + auto initiator_key_exchange_final_msg = + get_first(eio1, NodeMsgType::channel_msg); + if (drop_stage > DropStage::FinalMessage) + { + REQUIRE(channels2.recv_channel_message( + nid1, initiator_key_exchange_final_msg.data())); + + REQUIRE(channels1.get_status(nid2) == ESTABLISHED); + REQUIRE(channels2.get_status(nid1) == ESTABLISHED); + } + } + } + + INFO("Later attempts to connect should succeed"); + { + // Discard any pending messages + channels1.close_channel(nid2); + channels2.close_channel(nid1); + + SUBCASE("") + { + INFO("Node 1 attempts to connect"); + channels1.send_authenticated( + nid2, NodeMsgType::consensus_msg, msg.data(), msg.size()); + + REQUIRE(channels2.recv_channel_message( + nid1, get_first(eio1, NodeMsgType::channel_msg).data())); + REQUIRE(channels1.recv_channel_message( + nid2, get_first(eio2, NodeMsgType::channel_msg).data())); + REQUIRE(channels2.recv_channel_message( + nid1, get_first(eio1, NodeMsgType::channel_msg).data())); + } + else + { + INFO("Node 2 attempts to connect"); + channels2.send_authenticated( + nid1, NodeMsgType::consensus_msg, msg.data(), msg.size()); + + REQUIRE(channels1.recv_channel_message( + nid2, get_first(eio2, NodeMsgType::channel_msg).data())); + REQUIRE(channels2.recv_channel_message( + nid1, get_first(eio1, NodeMsgType::channel_msg).data())); + REQUIRE(channels1.recv_channel_message( + nid2, get_first(eio2, NodeMsgType::channel_msg).data())); + } + REQUIRE(channels1.get_status(nid2) == ESTABLISHED); + REQUIRE(channels2.get_status(nid1) == ESTABLISHED); + + MsgType aad; + aad.fill(0x10); + + REQUIRE(channels1.send_encrypted( + nid2, NodeMsgType::consensus_msg, {aad.data(), aad.size()}, msg)); + auto msg1 = get_first(eio1, NodeMsgType::consensus_msg); + auto decrypted1 = channels2.recv_encrypted( + nid1, + {msg1.authenticated_hdr.data(), msg1.authenticated_hdr.size()}, + msg1.payload.data(), + msg1.payload.size()); + REQUIRE(decrypted1 == msg); + + REQUIRE(channels2.send_encrypted( + nid1, NodeMsgType::consensus_msg, {aad.data(), aad.size()}, msg)); + auto msg2 = get_first(eio2, NodeMsgType::consensus_msg); + auto decrypted2 = channels1.recv_encrypted( + nid2, + {msg2.authenticated_hdr.data(), msg2.authenticated_hdr.size()}, + msg2.payload.data(), + msg2.payload.size()); + REQUIRE(decrypted2 == msg); + } + } +} + +TEST_CASE("Robust key exchange") +{ + auto network_kp = crypto::make_key_pair(default_curve); + auto network_cert = network_kp->self_sign("CN=Network"); + + auto channel1_kp = crypto::make_key_pair(default_curve); + auto channel1_csr = channel1_kp->create_csr("CN=Node1"); + auto channel1_cert = network_kp->sign_csr(network_cert, channel1_csr, {}); + + auto channel2_kp = crypto::make_key_pair(default_curve); + auto channel2_csr = channel2_kp->create_csr("CN=Node2"); + auto channel2_cert = network_kp->sign_csr(network_cert, channel2_csr, {}); + + const NodeId nid3 = std::string("nid3"); + + auto channels1 = NodeToNodeChannelManager(wf1); + channels1.initialize(nid1, network_cert, channel1_kp, channel1_cert); + auto channels2 = NodeToNodeChannelManager(wf2); + channels2.initialize(nid2, network_cert, channel2_kp, channel2_cert); + + MsgType aad; + aad.fill(0x10); + + std::vector payload; + payload.push_back(0x1); + payload.push_back(0x0); + payload.push_back(0x10); + payload.push_back(0x42); + + std::vector>> + old_messages; + { + INFO("Build a collection of old messages that could confuse the protocol"); + + channels1.send_encrypted( + nid2, NodeMsgType::consensus_msg, {aad.data(), aad.size()}, payload); + channels1.send_encrypted( + nid3, NodeMsgType::consensus_msg, {aad.data(), aad.size()}, payload); + channels1.send_encrypted( + nid2, NodeMsgType::consensus_msg, {aad.data(), aad.size()}, payload); + + auto outbound = read_outbound_msgs(eio1); + REQUIRE(outbound.size() >= 2); + for (size_t i = 0; i < outbound.size(); ++i) + { + const auto& msg = outbound[i]; + old_messages.push_back(std::make_tuple("too-early junk", i, msg.data())); + } + + channels1.close_channel(nid2); + channels1.close_channel(nid3); + channels1.send_encrypted( + nid2, NodeMsgType::consensus_msg, {aad.data(), aad.size()}, payload); + channels1.send_encrypted( + nid3, NodeMsgType::consensus_msg, {aad.data(), aad.size()}, payload); + channels1.send_encrypted( + nid2, NodeMsgType::consensus_msg, {aad.data(), aad.size()}, payload); + + outbound = read_outbound_msgs(eio1); + REQUIRE(outbound.size() >= 1); + auto kex_init = outbound.back(); + REQUIRE(kex_init.type == NodeMsgType::channel_msg); + for (size_t i = 0; i < outbound.size(); ++i) + { + const auto& msg = outbound[i]; + old_messages.push_back(std::make_tuple("initiation junk", i, msg.data())); + } + + channels2.send_encrypted( + nid1, NodeMsgType::consensus_msg, {aad.data(), aad.size()}, payload); + channels2.send_encrypted( + nid1, NodeMsgType::consensus_msg, {aad.data(), aad.size()}, payload); + channels2.send_encrypted( + nid3, NodeMsgType::consensus_msg, {aad.data(), aad.size()}, payload); + + outbound = read_outbound_msgs(eio2); + for (size_t i = 0; i < outbound.size(); ++i) + { + const auto& msg = outbound[i]; + old_messages.push_back( + std::make_tuple("counter-initiation junk", i, msg.data())); + } + + // Close attempted init, so we accept (lower priority) incoming init + channels2.close_channel(nid1); + + REQUIRE(channels2.recv_channel_message(nid1, kex_init.data())); + // Replaying an init is fine, equivalent to making a new attempt + // NB: Node 2 is now working with the _second_ exchange attempt, so to + // succeed we must deliver that instance + REQUIRE(channels2.recv_channel_message(nid1, kex_init.data())); + + outbound = read_outbound_msgs(eio2); + REQUIRE(outbound.size() >= 2); + auto kex_response = outbound.back(); + REQUIRE(kex_response.type == NodeMsgType::channel_msg); + for (size_t i = 0; i < outbound.size(); ++i) + { + const auto& msg = outbound[i]; + old_messages.push_back(std::make_tuple("response junk", i, msg.data())); + } + + REQUIRE(channels1.recv_channel_message(nid2, kex_response.data())); + REQUIRE_FALSE(channels1.recv_channel_message(nid2, kex_response.data())); + + outbound = read_outbound_msgs(eio1); + REQUIRE(outbound.size() == 2); + auto kex_final = outbound[0]; + REQUIRE(kex_final.type == NodeMsgType::channel_msg); + REQUIRE(outbound[1].type == NodeMsgType::consensus_msg); + for (size_t i = 0; i < outbound.size(); ++i) + { + const auto& msg = outbound[i]; + old_messages.push_back(std::make_tuple("final junk", i, msg.data())); + } + + REQUIRE(channels2.recv_channel_message(nid1, kex_final.data())); + REQUIRE_FALSE(channels2.recv_channel_message(nid1, kex_final.data())); + + REQUIRE(channels1.get_status(nid2) == ESTABLISHED); + REQUIRE(channels2.get_status(nid1) == ESTABLISHED); + + REQUIRE(channels1.send_encrypted( + nid2, NodeMsgType::consensus_msg, {aad.data(), aad.size()}, payload)); + channels1.send_encrypted( + nid3, NodeMsgType::consensus_msg, {aad.data(), aad.size()}, payload); + REQUIRE(channels2.send_encrypted( + nid1, NodeMsgType::consensus_msg, {aad.data(), aad.size()}, payload)); + channels2.send_encrypted( + nid3, NodeMsgType::consensus_msg, {aad.data(), aad.size()}, payload); + + outbound = read_outbound_msgs(eio1); + for (size_t i = 0; i < outbound.size(); ++i) + { + const auto& msg = outbound[i]; + old_messages.push_back(std::make_tuple("tailing junk A", i, msg.data())); + } + + outbound = read_outbound_msgs(eio2); + for (size_t i = 0; i < outbound.size(); ++i) + { + const auto& msg = outbound[i]; + old_messages.push_back(std::make_tuple("tailing junk B", i, msg.data())); + } + } + + channels1.close_channel(nid2); + channels2.close_channel(nid1); + + { + INFO("Mix key exchange with old messages"); + + auto receive_junk = [&]() { + std::random_device rd; + std::mt19937 g(rd()); + std::shuffle(old_messages.begin(), old_messages.end(), g); + + for (const auto& [label, i, msg] : old_messages) + { + // Uncomment this line to aid debugging if any of these fail + // std::cout << label << ": " << i << std::endl; + auto msg_1 = msg; + channels1.recv_channel_message(nid2, std::move(msg_1)); + auto msg_2 = msg; + channels2.recv_channel_message(nid1, std::move(msg_2)); + + // Remove anything they responded with from the ringbuffer + read_outbound_msgs(eio1); + read_outbound_msgs(eio2); + } + }; + + receive_junk(); + + channels1.send_authenticated( + nid2, NodeMsgType::consensus_msg, payload.data(), payload.size()); + + receive_junk(); + + channels1.close_channel(nid2); + channels2.close_channel(nid1); + + channels1.send_authenticated( + nid2, NodeMsgType::consensus_msg, payload.data(), payload.size()); + auto kex_init = get_first(eio1, NodeMsgType::channel_msg); + + REQUIRE(channels2.recv_channel_message(nid1, kex_init.data())); + + receive_junk(); + + channels1.close_channel(nid2); + channels2.close_channel(nid1); + + channels1.send_authenticated( + nid2, NodeMsgType::consensus_msg, payload.data(), payload.size()); + kex_init = get_first(eio1, NodeMsgType::channel_msg); + REQUIRE(channels2.recv_channel_message(nid1, kex_init.data())); + auto kex_response = get_first(eio2, NodeMsgType::channel_msg); + + REQUIRE(channels1.recv_channel_message(nid2, kex_response.data())); + + receive_junk(); + + channels1.close_channel(nid2); + channels2.close_channel(nid1); + + channels1.send_authenticated( + nid2, NodeMsgType::consensus_msg, payload.data(), payload.size()); + kex_init = get_first(eio1, NodeMsgType::channel_msg); + REQUIRE(channels2.recv_channel_message(nid1, kex_init.data())); + kex_response = get_first(eio2, NodeMsgType::channel_msg); + + REQUIRE(channels1.recv_channel_message(nid2, kex_response.data())); + auto kex_final = get_first(eio1, NodeMsgType::channel_msg); + + REQUIRE(channels2.recv_channel_message(nid1, kex_final.data())); + + REQUIRE(channels1.get_status(nid2) == ESTABLISHED); + REQUIRE(channels2.get_status(nid1) == ESTABLISHED); + + // We are not robust to new inits here! + // receive_junk(); + + REQUIRE(channels1.send_encrypted( + nid2, NodeMsgType::consensus_msg, {aad.data(), aad.size()}, payload)); + REQUIRE(channels2.send_encrypted( + nid1, NodeMsgType::consensus_msg, {aad.data(), aad.size()}, payload)); + } } \ No newline at end of file diff --git a/tests/infra/partitions.py b/tests/infra/partitions.py index 322242540ac5..139d6a76edda 100644 --- a/tests/infra/partitions.py +++ b/tests/infra/partitions.py @@ -88,6 +88,27 @@ def cleanup(): iptc.easy.delete_chain("filter", CCF_IPTABLES_CHAIN) LOG.info(f"{CCF_IPTABLES_CHAIN} iptables chain cleaned up") + @staticmethod + def reverse_rule(rule): + def swap_fields(obj, a, b): + res = {**obj} + if a in obj: + res[b] = obj[a] + else: + del res[b] + if b in obj: + res[a] = obj[b] + else: + del res[a] + return res + + r = swap_fields(rule, "src", "dst") + + if "tcp" in rule: + r["tcp"] = swap_fields(rule["tcp"], "sport", "dport") + + return r + def __init__(self, network): self.network = network @@ -140,17 +161,22 @@ def isolate_node( client_rule["dst"] = other.node_host name += f" from node {other.local_node_id}" - if iptc.easy.has_rule("filter", CCF_IPTABLES_CHAIN, server_rule): - iptc.easy.delete_rule("filter", CCF_IPTABLES_CHAIN, server_rule) + rules = [ + server_rule, + self.reverse_rule(server_rule), + client_rule, + self.reverse_rule(client_rule), + ] - if iptc.easy.has_rule("filter", CCF_IPTABLES_CHAIN, client_rule): - iptc.easy.delete_rule("filter", CCF_IPTABLES_CHAIN, client_rule) + for rule in rules: + if iptc.easy.has_rule("filter", CCF_IPTABLES_CHAIN, rule): + iptc.easy.delete_rule("filter", CCF_IPTABLES_CHAIN, rule) - iptc.easy.insert_rule("filter", CCF_IPTABLES_CHAIN, server_rule) - iptc.easy.insert_rule("filter", CCF_IPTABLES_CHAIN, client_rule) + iptc.easy.insert_rule("filter", CCF_IPTABLES_CHAIN, rule) LOG.debug(name) - return Rules([server_rule, client_rule], name) + + return Rules(rules, name) @staticmethod def _get_partition_name(partition: List[infra.node.Node]): diff --git a/tests/partitions_test.py b/tests/partitions_test.py index e6d847dbf7d4..0cc2901a1a56 100644 --- a/tests/partitions_test.py +++ b/tests/partitions_test.py @@ -104,7 +104,7 @@ def test_isolate_primary_from_one_backup(network, args): @reqs.description("Isolate and reconnect primary") -def test_isolate_and_reconnect_primary(network, args): +def test_isolate_and_reconnect_primary(network, args, **kwargs): primary, backups = network.find_nodes() with network.partitioner.partition(backups): lost_tx_resp = check_does_not_progress(primary) @@ -147,8 +147,8 @@ def run(args): test_invalid_partitions(network, args) test_partition_majority(network, args) test_isolate_primary_from_one_backup(network, args) - for _ in range(5): - test_isolate_and_reconnect_primary(network, args) + for n in range(5): + test_isolate_and_reconnect_primary(network, args, iteration=n) if __name__ == "__main__":