Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Handshake with query for n2c #266

Merged
merged 1 commit into from
Jun 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion examples/n2c-miniprotocols/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,15 @@ async fn main() {

// we connect to the unix socket of the local node. Make sure you have the right
// path for your environment
let mut client = NodeClient::connect("/tmp/node.socket", MAINNET_MAGIC)
let socket_path = "/tmp/node.socket";

// we connect to the unix socket of the local node and perform a handshake query
let version_table = NodeClient::handshake_query(socket_path, MAINNET_MAGIC)
.await
.unwrap();
info!("handshake query result: {:?}", version_table);

let mut client = NodeClient::connect(socket_path, MAINNET_MAGIC)
.await
.unwrap();

Expand Down
39 changes: 39 additions & 0 deletions pallas-network/src/facades.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use crate::{
},
multiplexer::{self, Bearer},
};
use crate::miniprotocols::handshake::Confirmation;

#[derive(Debug, Error)]
pub enum Error {
Expand Down Expand Up @@ -126,6 +127,44 @@ impl NodeClient {
})
}

#[cfg(not(target_os = "windows"))]
pub async fn handshake_query(path: impl AsRef<Path>, magic: u64) -> Result<handshake::n2c::VersionTable, Error> {
debug!("connecting");

let bearer = Bearer::connect_unix(path)
.await
.map_err(Error::ConnectFailure)?;

let mut plexer = multiplexer::Plexer::new(bearer);

let hs_channel = plexer.subscribe_client(PROTOCOL_N2C_HANDSHAKE);

let plexer_handle = tokio::spawn(async move { plexer.run().await });

let versions = handshake::n2c::VersionTable::v15_with_query(magic);
let mut client = handshake::Client::new(hs_channel);

let handshake = client
.handshake(versions)
.await
.map_err(Error::HandshakeProtocol)?;

match handshake {
Confirmation::Accepted(_, _) => {
error!("handshake accepted when we expected query reply");
Err(Error::HandshakeProtocol(handshake::Error::InvalidInbound))
}
Confirmation::Rejected(reason) => {
error!(?reason, "handshake refused");
Err(Error::IncompatibleVersion)
}
Confirmation::QueryReply(version_table) => {
plexer_handle.abort();
Ok(version_table)
}
}
}

pub fn chainsync(&mut self) -> &mut chainsync::N2CClient {
&mut self.chainsync
}
Expand Down
12 changes: 10 additions & 2 deletions pallas-network/src/miniprotocols/handshake/client.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::fmt::Debug;
use pallas_codec::Fragment;
use std::marker::PhantomData;
use tracing::debug;
Expand All @@ -6,16 +7,17 @@ use super::{Error, Message, RefuseReason, State, VersionNumber, VersionTable};
use crate::multiplexer;

#[derive(Debug)]
pub enum Confirmation<D> {
pub enum Confirmation<D: Debug + Clone> {
Accepted(VersionNumber, D),
Rejected(RefuseReason),
QueryReply(VersionTable<D>),
}

pub struct Client<D>(State, multiplexer::ChannelBuffer, PhantomData<D>);

impl<D> Client<D>
where
D: std::fmt::Debug + Clone,
D: Debug + Clone,
Message<D>: Fragment,
{
pub fn new(channel: multiplexer::AgentChannel) -> Self {
Expand Down Expand Up @@ -69,6 +71,7 @@ where
match (&self.0, msg) {
(State::Confirm, Message::Accept(..)) => Ok(()),
(State::Confirm, Message::Refuse(..)) => Ok(()),
(State::Confirm, Message::QueryReply(..)) => Ok(()),
_ => Err(Error::InvalidInbound),
}
}
Expand Down Expand Up @@ -113,6 +116,11 @@ where

Ok(Confirmation::Rejected(r))
}
Message::QueryReply(version_table) => {
debug!("handshake query reply");

Ok(Confirmation::QueryReply(version_table))
}
_ => Err(Error::InvalidInbound),
}
}
Expand Down
83 changes: 61 additions & 22 deletions pallas-network/src/miniprotocols/handshake/n2c.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::collections::HashMap;

use pallas_codec::minicbor::{decode, encode, Decode, Decoder, Encode, Encoder};
use pallas_codec::minicbor::{decode, Decode, Decoder, encode, Encode, Encoder};
use pallas_codec::minicbor::data::Type;

use super::protocol::NetworkMagic;

Expand All @@ -18,22 +19,28 @@ const PROTOCOL_V9: u64 = 32777;
const PROTOCOL_V10: u64 = 32778;
const PROTOCOL_V11: u64 = 32779;
const PROTOCOL_V12: u64 = 32780;
const PROTOCOL_V13: u64 = 32781;
const PROTOCOL_V14: u64 = 32782;
const PROTOCOL_V15: u64 = 32783;

impl VersionTable {
pub fn v1_and_above(network_magic: u64) -> VersionTable {
let values = vec![
(PROTOCOL_V1, VersionData(network_magic)),
(PROTOCOL_V2, VersionData(network_magic)),
(PROTOCOL_V3, VersionData(network_magic)),
(PROTOCOL_V4, VersionData(network_magic)),
(PROTOCOL_V5, VersionData(network_magic)),
(PROTOCOL_V6, VersionData(network_magic)),
(PROTOCOL_V7, VersionData(network_magic)),
(PROTOCOL_V8, VersionData(network_magic)),
(PROTOCOL_V9, VersionData(network_magic)),
(PROTOCOL_V10, VersionData(network_magic)),
(PROTOCOL_V11, VersionData(network_magic)),
(PROTOCOL_V12, VersionData(network_magic)),
(PROTOCOL_V1, VersionData(network_magic, None)),
(PROTOCOL_V2, VersionData(network_magic, None)),
(PROTOCOL_V3, VersionData(network_magic, None)),
(PROTOCOL_V4, VersionData(network_magic, None)),
(PROTOCOL_V5, VersionData(network_magic, None)),
(PROTOCOL_V6, VersionData(network_magic, None)),
(PROTOCOL_V7, VersionData(network_magic, None)),
(PROTOCOL_V8, VersionData(network_magic, None)),
(PROTOCOL_V9, VersionData(network_magic, None)),
(PROTOCOL_V10, VersionData(network_magic, None)),
(PROTOCOL_V11, VersionData(network_magic, None)),
(PROTOCOL_V12, VersionData(network_magic, None)),
(PROTOCOL_V13, VersionData(network_magic, None)),
(PROTOCOL_V14, VersionData(network_magic, None)),
(PROTOCOL_V15, VersionData(network_magic, Some(false))),
]
.into_iter()
.collect::<HashMap<u64, VersionData>>();
Expand All @@ -42,7 +49,7 @@ impl VersionTable {
}

pub fn only_v10(network_magic: u64) -> VersionTable {
let values = vec![(PROTOCOL_V10, VersionData(network_magic))]
let values = vec![(PROTOCOL_V10, VersionData(network_magic, None))]
.into_iter()
.collect::<HashMap<u64, VersionData>>();

Expand All @@ -51,9 +58,22 @@ impl VersionTable {

pub fn v10_and_above(network_magic: u64) -> VersionTable {
let values = vec![
(PROTOCOL_V10, VersionData(network_magic)),
(PROTOCOL_V11, VersionData(network_magic)),
(PROTOCOL_V12, VersionData(network_magic)),
(PROTOCOL_V10, VersionData(network_magic, None)),
(PROTOCOL_V11, VersionData(network_magic, None)),
(PROTOCOL_V12, VersionData(network_magic, None)),
(PROTOCOL_V13, VersionData(network_magic, None)),
(PROTOCOL_V14, VersionData(network_magic, None)),
(PROTOCOL_V15, VersionData(network_magic, Some(false))),
]
.into_iter()
.collect::<HashMap<u64, VersionData>>();

VersionTable { values }
}

pub fn v15_with_query(network_magic: u64) -> VersionTable {
let values = vec![
(PROTOCOL_V15, VersionData(network_magic, Some(true))),
]
.into_iter()
.collect::<HashMap<u64, VersionData>>();
Expand All @@ -63,24 +83,43 @@ impl VersionTable {
}

#[derive(Debug, Clone)]
pub struct VersionData(NetworkMagic);
pub struct VersionData(NetworkMagic, Option<bool>);

impl Encode<()> for VersionData {
fn encode<W: encode::Write>(
&self,
e: &mut Encoder<W>,
_ctx: &mut (),
) -> Result<(), encode::Error<W::Error>> {
e.u64(self.0)?;
match self.1 {
None => { e.u64(self.0)?; }
Some(is_query) => {
e.array(2)?;
e.u64(self.0)?;
e.bool(is_query)?;
}
}

Ok(())
}
}

impl<'b> Decode<'b, ()> for VersionData {
fn decode(d: &mut Decoder<'b>, _ctx: &mut ()) -> Result<Self, decode::Error> {
let network_magic = d.u64()?;

Ok(Self(network_magic))
match d.datatype()? {
Type::U8 | Type::U16 | Type::U32 | Type::U64 => {
let network_magic = d.u64()?;
Ok(Self(network_magic, None))
}
Type::Array => {
d.array()?;
let network_magic = d.u64()?;
let is_query = d.bool()?;
Ok(Self(network_magic, Some(is_query)))
}
_ => Err(decode::Error::message(
"unknown type for VersionData",
)),
}
}
}
20 changes: 18 additions & 2 deletions pallas-network/src/miniprotocols/handshake/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,15 @@ impl<'b, T> Decode<'b, ()> for VersionTable<T>
where
T: Debug + Clone + Decode<'b, ()>,
{
fn decode(d: &mut Decoder<'b>, ctx: &mut ()) -> Result<Self, decode::Error> {
let values = d.map_iter_with(ctx)?.collect::<Result<_, _>>()?;
fn decode(d: &mut Decoder<'b>, _ctx: &mut ()) -> Result<Self, decode::Error> {
let len = d.map()?.ok_or(decode::Error::message("expected def-length map for versiontable"))?;
let mut values = HashMap::new();

for _ in 0..len {
let key = d.u64()?;
let value = d.decode()?;
values.insert(key, value);
}
Ok(VersionTable { values })
}
}
Expand All @@ -73,6 +80,7 @@ where
Propose(VersionTable<D>),
Accept(VersionNumber, D),
Refuse(RefuseReason),
QueryReply(VersionTable<D>),
}

impl<D> Encode<()> for Message<D>
Expand Down Expand Up @@ -100,6 +108,10 @@ where
e.array(2)?.u16(2)?;
e.encode(reason)?;
}
Message::QueryReply(version_table) => {
e.array(2)?.u16(3)?;
e.encode(version_table)?;
}
};

Ok(())
Expand Down Expand Up @@ -128,6 +140,10 @@ where
let reason: RefuseReason = d.decode()?;
Ok(Message::Refuse(reason))
}
3 => {
let version_table = d.decode()?;
Ok(Message::QueryReply(version_table))
}
_ => Err(decode::Error::message(
"unknown variant for handshake message",
)),
Expand Down
2 changes: 1 addition & 1 deletion pallas-network/src/multiplexer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use tokio::time::Instant;
use tracing::{debug, error, trace};

#[cfg(not(target_os = "windows"))]
use UnixStream;
use tokio::net::UnixStream;

const HEADER_LEN: usize = 8;

Expand Down