From 1b7ab76203821ad807d9dbac33b52c1e1663a0b6 Mon Sep 17 00:00:00 2001 From: Andrew Westberg Date: Sat, 24 Jun 2023 15:29:51 +0000 Subject: [PATCH] Add Handshake with query for n2c --- examples/n2c-miniprotocols/src/main.rs | 10 ++- pallas-network/src/facades.rs | 39 +++++++++ .../src/miniprotocols/handshake/client.rs | 12 ++- .../src/miniprotocols/handshake/n2c.rs | 83 ++++++++++++++----- .../src/miniprotocols/handshake/protocol.rs | 20 ++++- pallas-network/src/multiplexer.rs | 2 +- 6 files changed, 138 insertions(+), 28 deletions(-) diff --git a/examples/n2c-miniprotocols/src/main.rs b/examples/n2c-miniprotocols/src/main.rs index 99a07e80..4dec5f29 100644 --- a/examples/n2c-miniprotocols/src/main.rs +++ b/examples/n2c-miniprotocols/src/main.rs @@ -55,7 +55,15 @@ async fn main() { // we connect to the unix socket of the local node. Make sure you have the right // path for your environment - let mut client = NodeClient::connect("/tmp/node.socket", MAINNET_MAGIC) + let socket_path = "/tmp/node.socket"; + + // we connect to the unix socket of the local node and perform a handshake query + let version_table = NodeClient::handshake_query(socket_path, MAINNET_MAGIC) + .await + .unwrap(); + info!("handshake query result: {:?}", version_table); + + let mut client = NodeClient::connect(socket_path, MAINNET_MAGIC) .await .unwrap(); diff --git a/pallas-network/src/facades.rs b/pallas-network/src/facades.rs index 1e5fd994..4dc0ef1d 100644 --- a/pallas-network/src/facades.rs +++ b/pallas-network/src/facades.rs @@ -11,6 +11,7 @@ use crate::{ }, multiplexer::{self, Bearer}, }; +use crate::miniprotocols::handshake::Confirmation; #[derive(Debug, Error)] pub enum Error { @@ -126,6 +127,44 @@ impl NodeClient { }) } + #[cfg(not(target_os = "windows"))] + pub async fn handshake_query(path: impl AsRef, magic: u64) -> Result { + debug!("connecting"); + + let bearer = Bearer::connect_unix(path) + .await + .map_err(Error::ConnectFailure)?; + + let mut plexer = multiplexer::Plexer::new(bearer); + + let hs_channel = plexer.subscribe_client(PROTOCOL_N2C_HANDSHAKE); + + let plexer_handle = tokio::spawn(async move { plexer.run().await }); + + let versions = handshake::n2c::VersionTable::v15_with_query(magic); + let mut client = handshake::Client::new(hs_channel); + + let handshake = client + .handshake(versions) + .await + .map_err(Error::HandshakeProtocol)?; + + match handshake { + Confirmation::Accepted(_, _) => { + error!("handshake accepted when we expected query reply"); + Err(Error::HandshakeProtocol(handshake::Error::InvalidInbound)) + } + Confirmation::Rejected(reason) => { + error!(?reason, "handshake refused"); + Err(Error::IncompatibleVersion) + } + Confirmation::QueryReply(version_table) => { + plexer_handle.abort(); + Ok(version_table) + } + } + } + pub fn chainsync(&mut self) -> &mut chainsync::N2CClient { &mut self.chainsync } diff --git a/pallas-network/src/miniprotocols/handshake/client.rs b/pallas-network/src/miniprotocols/handshake/client.rs index 0a795ea1..4331d4fa 100644 --- a/pallas-network/src/miniprotocols/handshake/client.rs +++ b/pallas-network/src/miniprotocols/handshake/client.rs @@ -1,3 +1,4 @@ +use std::fmt::Debug; use pallas_codec::Fragment; use std::marker::PhantomData; use tracing::debug; @@ -6,16 +7,17 @@ use super::{Error, Message, RefuseReason, State, VersionNumber, VersionTable}; use crate::multiplexer; #[derive(Debug)] -pub enum Confirmation { +pub enum Confirmation { Accepted(VersionNumber, D), Rejected(RefuseReason), + QueryReply(VersionTable), } pub struct Client(State, multiplexer::ChannelBuffer, PhantomData); impl Client where - D: std::fmt::Debug + Clone, + D: Debug + Clone, Message: Fragment, { pub fn new(channel: multiplexer::AgentChannel) -> Self { @@ -69,6 +71,7 @@ where match (&self.0, msg) { (State::Confirm, Message::Accept(..)) => Ok(()), (State::Confirm, Message::Refuse(..)) => Ok(()), + (State::Confirm, Message::QueryReply(..)) => Ok(()), _ => Err(Error::InvalidInbound), } } @@ -113,6 +116,11 @@ where Ok(Confirmation::Rejected(r)) } + Message::QueryReply(version_table) => { + debug!("handshake query reply"); + + Ok(Confirmation::QueryReply(version_table)) + } _ => Err(Error::InvalidInbound), } } diff --git a/pallas-network/src/miniprotocols/handshake/n2c.rs b/pallas-network/src/miniprotocols/handshake/n2c.rs index 212998ea..4a147358 100644 --- a/pallas-network/src/miniprotocols/handshake/n2c.rs +++ b/pallas-network/src/miniprotocols/handshake/n2c.rs @@ -1,6 +1,7 @@ use std::collections::HashMap; -use pallas_codec::minicbor::{decode, encode, Decode, Decoder, Encode, Encoder}; +use pallas_codec::minicbor::{decode, Decode, Decoder, encode, Encode, Encoder}; +use pallas_codec::minicbor::data::Type; use super::protocol::NetworkMagic; @@ -18,22 +19,28 @@ const PROTOCOL_V9: u64 = 32777; const PROTOCOL_V10: u64 = 32778; const PROTOCOL_V11: u64 = 32779; const PROTOCOL_V12: u64 = 32780; +const PROTOCOL_V13: u64 = 32781; +const PROTOCOL_V14: u64 = 32782; +const PROTOCOL_V15: u64 = 32783; impl VersionTable { pub fn v1_and_above(network_magic: u64) -> VersionTable { let values = vec![ - (PROTOCOL_V1, VersionData(network_magic)), - (PROTOCOL_V2, VersionData(network_magic)), - (PROTOCOL_V3, VersionData(network_magic)), - (PROTOCOL_V4, VersionData(network_magic)), - (PROTOCOL_V5, VersionData(network_magic)), - (PROTOCOL_V6, VersionData(network_magic)), - (PROTOCOL_V7, VersionData(network_magic)), - (PROTOCOL_V8, VersionData(network_magic)), - (PROTOCOL_V9, VersionData(network_magic)), - (PROTOCOL_V10, VersionData(network_magic)), - (PROTOCOL_V11, VersionData(network_magic)), - (PROTOCOL_V12, VersionData(network_magic)), + (PROTOCOL_V1, VersionData(network_magic, None)), + (PROTOCOL_V2, VersionData(network_magic, None)), + (PROTOCOL_V3, VersionData(network_magic, None)), + (PROTOCOL_V4, VersionData(network_magic, None)), + (PROTOCOL_V5, VersionData(network_magic, None)), + (PROTOCOL_V6, VersionData(network_magic, None)), + (PROTOCOL_V7, VersionData(network_magic, None)), + (PROTOCOL_V8, VersionData(network_magic, None)), + (PROTOCOL_V9, VersionData(network_magic, None)), + (PROTOCOL_V10, VersionData(network_magic, None)), + (PROTOCOL_V11, VersionData(network_magic, None)), + (PROTOCOL_V12, VersionData(network_magic, None)), + (PROTOCOL_V13, VersionData(network_magic, None)), + (PROTOCOL_V14, VersionData(network_magic, None)), + (PROTOCOL_V15, VersionData(network_magic, Some(false))), ] .into_iter() .collect::>(); @@ -42,7 +49,7 @@ impl VersionTable { } pub fn only_v10(network_magic: u64) -> VersionTable { - let values = vec![(PROTOCOL_V10, VersionData(network_magic))] + let values = vec![(PROTOCOL_V10, VersionData(network_magic, None))] .into_iter() .collect::>(); @@ -51,9 +58,22 @@ impl VersionTable { pub fn v10_and_above(network_magic: u64) -> VersionTable { let values = vec![ - (PROTOCOL_V10, VersionData(network_magic)), - (PROTOCOL_V11, VersionData(network_magic)), - (PROTOCOL_V12, VersionData(network_magic)), + (PROTOCOL_V10, VersionData(network_magic, None)), + (PROTOCOL_V11, VersionData(network_magic, None)), + (PROTOCOL_V12, VersionData(network_magic, None)), + (PROTOCOL_V13, VersionData(network_magic, None)), + (PROTOCOL_V14, VersionData(network_magic, None)), + (PROTOCOL_V15, VersionData(network_magic, Some(false))), + ] + .into_iter() + .collect::>(); + + VersionTable { values } + } + + pub fn v15_with_query(network_magic: u64) -> VersionTable { + let values = vec![ + (PROTOCOL_V15, VersionData(network_magic, Some(true))), ] .into_iter() .collect::>(); @@ -63,7 +83,7 @@ impl VersionTable { } #[derive(Debug, Clone)] -pub struct VersionData(NetworkMagic); +pub struct VersionData(NetworkMagic, Option); impl Encode<()> for VersionData { fn encode( @@ -71,7 +91,14 @@ impl Encode<()> for VersionData { e: &mut Encoder, _ctx: &mut (), ) -> Result<(), encode::Error> { - e.u64(self.0)?; + match self.1 { + None => { e.u64(self.0)?; } + Some(is_query) => { + e.array(2)?; + e.u64(self.0)?; + e.bool(is_query)?; + } + } Ok(()) } @@ -79,8 +106,20 @@ impl Encode<()> for VersionData { impl<'b> Decode<'b, ()> for VersionData { fn decode(d: &mut Decoder<'b>, _ctx: &mut ()) -> Result { - let network_magic = d.u64()?; - - Ok(Self(network_magic)) + match d.datatype()? { + Type::U8 | Type::U16 | Type::U32 | Type::U64 => { + let network_magic = d.u64()?; + Ok(Self(network_magic, None)) + } + Type::Array => { + d.array()?; + let network_magic = d.u64()?; + let is_query = d.bool()?; + Ok(Self(network_magic, Some(is_query))) + } + _ => Err(decode::Error::message( + "unknown type for VersionData", + )), + } } } diff --git a/pallas-network/src/miniprotocols/handshake/protocol.rs b/pallas-network/src/miniprotocols/handshake/protocol.rs index 1ab2a28d..4e6b33c2 100644 --- a/pallas-network/src/miniprotocols/handshake/protocol.rs +++ b/pallas-network/src/miniprotocols/handshake/protocol.rs @@ -55,8 +55,15 @@ impl<'b, T> Decode<'b, ()> for VersionTable where T: Debug + Clone + Decode<'b, ()>, { - fn decode(d: &mut Decoder<'b>, ctx: &mut ()) -> Result { - let values = d.map_iter_with(ctx)?.collect::>()?; + fn decode(d: &mut Decoder<'b>, _ctx: &mut ()) -> Result { + let len = d.map()?.ok_or(decode::Error::message("expected def-length map for versiontable"))?; + let mut values = HashMap::new(); + + for _ in 0..len { + let key = d.u64()?; + let value = d.decode()?; + values.insert(key, value); + } Ok(VersionTable { values }) } } @@ -73,6 +80,7 @@ where Propose(VersionTable), Accept(VersionNumber, D), Refuse(RefuseReason), + QueryReply(VersionTable), } impl Encode<()> for Message @@ -100,6 +108,10 @@ where e.array(2)?.u16(2)?; e.encode(reason)?; } + Message::QueryReply(version_table) => { + e.array(2)?.u16(3)?; + e.encode(version_table)?; + } }; Ok(()) @@ -128,6 +140,10 @@ where let reason: RefuseReason = d.decode()?; Ok(Message::Refuse(reason)) } + 3 => { + let version_table = d.decode()?; + Ok(Message::QueryReply(version_table)) + } _ => Err(decode::Error::message( "unknown variant for handshake message", )), diff --git a/pallas-network/src/multiplexer.rs b/pallas-network/src/multiplexer.rs index ec9530fd..4a293f50 100644 --- a/pallas-network/src/multiplexer.rs +++ b/pallas-network/src/multiplexer.rs @@ -13,7 +13,7 @@ use tokio::time::Instant; use tracing::{debug, error, trace}; #[cfg(not(target_os = "windows"))] -use UnixStream; +use tokio::net::UnixStream; const HEADER_LEN: usize = 8;