Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(multiplexer): Use single channel for muxer #133

Merged
merged 2 commits into from
Jun 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/n2n-miniprotocols/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ impl chainsync::Observer<chainsync::HeaderContent> for LoggingObserver {
_content: chainsync::HeaderContent,
tip: &chainsync::Tip,
) -> Result<chainsync::Continuation, Box<dyn std::error::Error>> {
log::debug!("asked to roll forward, tip at {:?}", tip);
log::info!("asked to roll forward, tip at {:?}", tip);

Ok(chainsync::Continuation::Proceed)
}
Expand Down Expand Up @@ -96,7 +96,7 @@ fn do_chainsync(mut channel: ChannelBuffer<StdChannel>) {

fn main() {
env_logger::builder()
.filter_level(log::LevelFilter::Trace)
.filter_level(log::LevelFilter::Info)
.init();

// setup a TCP socket to act as data bearer between our agents and the remote
Expand Down
60 changes: 60 additions & 0 deletions pallas-multiplexer/src/agents.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,63 @@ impl<C: Channel> From<C> for ChannelBuffer<C> {
ChannelBuffer::new(channel)
}
}

#[cfg(test)]
mod tests {
use std::collections::VecDeque;

use super::*;

impl Channel for VecDeque<Payload> {
fn enqueue_chunk(&mut self, chunk: Payload) -> Result<(), ChannelError> {
self.push_back(chunk);
Ok(())
}

fn dequeue_chunk(&mut self) -> Result<Payload, ChannelError> {
let chunk = self.pop_front().ok_or(ChannelError::NotConnected(None))?;
Ok(chunk)
}
}

#[test]
fn multiple_messages_in_same_payload() {
let mut input = Vec::new();
let in_part1 = (1u8, 2u8, 3u8);
let in_part2 = (6u8, 5u8, 4u8);

minicbor::encode(in_part1, &mut input).unwrap();
minicbor::encode(in_part2, &mut input).unwrap();

let mut channel = VecDeque::<Payload>::new();
channel.push_back(input);

let mut buf = ChannelBuffer::new(channel);

let out_part1 = buf.recv_full_msg::<(u8, u8, u8)>().unwrap();
let out_part2 = buf.recv_full_msg::<(u8, u8, u8)>().unwrap();

assert_eq!(in_part1, out_part1);
assert_eq!(in_part2, out_part2);
}

#[test]
fn fragmented_message_in_multiple_payloads() {
let mut input = Vec::new();
let msg = (11u8, 12u8, 13u8, 14u8, 15u8, 16u8, 17u8);
minicbor::encode(msg, &mut input).unwrap();

let mut channel = VecDeque::<Payload>::new();

while !input.is_empty() {
let chunk = Vec::from(input.drain(0..2).as_slice());
channel.push_back(chunk);
}

let mut buf = ChannelBuffer::new(channel);

let out_msg = buf.recv_full_msg::<(u8, u8, u8, u8, u8, u8, u8)>().unwrap();

assert_eq!(msg, out_msg);
}
}
29 changes: 1 addition & 28 deletions pallas-multiplexer/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ pub mod bearers;
pub mod demux;
pub mod mux;

use bearers::Bearer;

#[cfg(feature = "std")]
mod std;

Expand All @@ -13,29 +11,4 @@ pub use crate::std::*;

pub type Payload = Vec<u8>;

pub struct Multiplexer<I, E>
where
I: mux::Ingress,
E: demux::Egress,
{
pub muxer: mux::Muxer<I>,
pub demuxer: demux::Demuxer<E>,
}

impl<I, E> Multiplexer<I, E>
where
I: mux::Ingress,
E: demux::Egress,
{
pub fn new(bearer: Bearer) -> Self {
Multiplexer {
muxer: mux::Muxer::new(bearer.clone()),
demuxer: demux::Demuxer::new(bearer),
}
}

pub fn register_channel(&mut self, protocol: u16, ingress: I, egress: E) {
self.muxer.register(protocol, ingress);
self.demuxer.register(protocol, egress);
}
}
pub type Message = (u16, Payload);
66 changes: 11 additions & 55 deletions pallas-multiplexer/src/mux.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
use std::{collections::HashMap, time::Instant};

use rand::seq::SliceRandom;
use rand::thread_rng;
use std::time::{Duration, Instant};

use crate::{
bearers::{Bearer, Segment},
Payload,
Message,
};

pub enum IngressError {
Expand All @@ -18,87 +15,46 @@ pub enum IngressError {
/// To be implemented by any mechanism that allows to submit a payloads from a
/// particular protocol that need to be muxed by the multiplexer.
pub trait Ingress {
fn try_recv(&mut self) -> Result<Payload, IngressError>;
fn recv_timeout(&mut self, duration: Duration) -> Result<Message, IngressError>;
}

type Message = (u16, Payload);

pub enum TickOutcome {
BearerError(std::io::Error),
IngressDisconnected,
Idle,
Busy,
}

pub struct Muxer<I> {
bearer: Bearer,
ingress: HashMap<u16, I>,
ingress: I,
clock: Instant,
}

impl<I> Muxer<I>
where
I: Ingress,
{
pub fn new(bearer: Bearer) -> Self {
pub fn new(bearer: Bearer, ingress: I) -> Self {
Self {
bearer,
ingress: Default::default(),
ingress,
clock: Instant::now(),
}
}

/// Register the receiver end of an ingress channel
pub fn register(&mut self, id: u16, rx: I) {
self.ingress.insert(id, rx);
}

/// Remove a protocol from the ingress
///
/// Meant to be used after a receive error in a previous tick
pub fn deregister(&mut self, id: u16) {
self.ingress.remove(&id);
}

#[inline]
fn randomize_ids(&self) -> Vec<u16> {
let mut rng = thread_rng();
let mut keys: Vec<_> = self.ingress.keys().cloned().collect();
keys.shuffle(&mut rng);
keys
}

/// Select the next segment to be muxed
///
/// This method iterates over the existing receivers checking for the first
/// available message. The order of the checks is random to ensure a fair
/// use of the multiplexer amongst all protocols.
pub fn select(&mut self) -> Option<Message> {
for id in self.randomize_ids() {
let rx = self.ingress.get_mut(&id).unwrap();

match rx.try_recv() {
Ok(payload) => return Some((id, payload)),
Err(IngressError::Disconnected) => {
self.deregister(id);
}
_ => (),
};
}

None
}

pub fn tick(&mut self) -> TickOutcome {
match self.select() {
Some((id, payload)) => {
match self.ingress.recv_timeout(Duration::from_millis(500)) {
Ok((id, payload)) => {
let segment = Segment::new(self.clock, id, payload);

match self.bearer.write_segment(segment) {
Err(err) => TickOutcome::BearerError(err),
_ => TickOutcome::Busy,
}
}
None => TickOutcome::Idle,
Err(IngressError::Empty) => TickOutcome::Idle,
Err(IngressError::Disconnected) => TickOutcome::IngressDisconnected,
}
}
}
51 changes: 32 additions & 19 deletions pallas-multiplexer/src/std.rs
Original file line number Diff line number Diff line change
@@ -1,26 +1,27 @@
use crate::{
agents::{self, ChannelBuffer},
demux, mux, Payload,
bearers::Bearer,
demux, mux, Message, Payload,
};

use std::{
sync::{
atomic::{AtomicBool, Ordering},
mpsc::{channel, Receiver, SendError, Sender, TryRecvError},
mpsc::{channel, Receiver, RecvTimeoutError, SendError, Sender},
Arc,
},
thread::{spawn, JoinHandle},
time::Duration,
};

pub type StdIngress = Receiver<Payload>;
pub type StdIngress = Receiver<Message>;

impl mux::Ingress for StdIngress {
fn try_recv(&mut self) -> Result<Payload, mux::IngressError> {
match Receiver::try_recv(self) {
fn recv_timeout(&mut self, duration: Duration) -> Result<Message, mux::IngressError> {
match Receiver::recv_timeout(self, duration) {
Ok(x) => Ok(x),
Err(TryRecvError::Disconnected) => Err(mux::IngressError::Disconnected),
Err(TryRecvError::Empty) => Err(mux::IngressError::Empty),
Err(RecvTimeoutError::Disconnected) => Err(mux::IngressError::Disconnected),
Err(RecvTimeoutError::Timeout) => Err(mux::IngressError::Empty),
}
}
}
Expand All @@ -36,16 +37,30 @@ impl demux::Egress for StdEgress {
}
}

pub type StdPlexer = crate::Multiplexer<StdIngress, StdEgress>;
pub struct StdPlexer {
pub muxer: mux::Muxer<StdIngress>,
pub demuxer: demux::Demuxer<StdEgress>,
pub mux_tx: Sender<Message>,
}

impl StdPlexer {
pub fn new(bearer: Bearer) -> Self {
let (mux_tx, mux_rx) = channel::<Message>();

Self {
muxer: mux::Muxer::new(bearer.clone(), mux_rx),
demuxer: demux::Demuxer::new(bearer),
mux_tx,
}
}

pub fn use_channel(&mut self, protocol: u16) -> StdChannel {
let (demux_tx, demux_rx) = channel::<Payload>();
let (mux_tx, mux_rx) = channel::<Payload>();
self.demuxer.register(protocol, demux_tx);

self.register_channel(protocol, mux_rx, demux_tx);
let mux_tx = self.mux_tx.clone();

(mux_tx, demux_rx)
(protocol, mux_tx, demux_rx)
}
}

Expand All @@ -56,12 +71,10 @@ impl mux::Muxer<StdIngress> {
mux::TickOutcome::BearerError(err) => return Err(err),
mux::TickOutcome::Idle => match cancel.is_set() {
true => break Ok(()),
false => {
// TODO: investigate why std::thread::yield_now() hogs the thread
std::thread::sleep(Duration::from_millis(100))
}
false => (),
},
mux::TickOutcome::Busy => (),
mux::TickOutcome::IngressDisconnected => break Ok(()),
}
}
}
Expand Down Expand Up @@ -104,20 +117,20 @@ impl demux::Demuxer<StdEgress> {
}
}

pub type StdChannel = (Sender<Payload>, Receiver<Payload>);
pub type StdChannel = (u16, Sender<Message>, Receiver<Payload>);

pub type StdChannelBuffer = ChannelBuffer<StdChannel>;

impl agents::Channel for StdChannel {
fn enqueue_chunk(&mut self, payload: Payload) -> Result<(), agents::ChannelError> {
match self.0.send(payload) {
match self.1.send((self.0, payload)) {
Ok(_) => Ok(()),
Err(SendError(payload)) => Err(agents::ChannelError::NotConnected(Some(payload))),
Err(SendError((_, payload))) => Err(agents::ChannelError::NotConnected(Some(payload))),
}
}

fn dequeue_chunk(&mut self) -> Result<Payload, agents::ChannelError> {
match self.1.recv() {
match self.2.recv() {
Ok(payload) => Ok(payload),
Err(_) => Err(agents::ChannelError::NotConnected(None)),
}
Expand Down
43 changes: 1 addition & 42 deletions pallas-multiplexer/tests/integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use pallas_codec::minicbor;
use pallas_multiplexer::{
agents::{Channel, ChannelBuffer},
bearers::Bearer,
StdPlexer,
Payload, StdPlexer,
};
use rand::{distributions::Uniform, Rng};

Expand Down Expand Up @@ -62,44 +62,3 @@ fn one_way_small_sequence_of_payloads() {
assert_eq!(payload, received_payload);
}
}

#[test]
fn multiple_messages_in_same_payload() {
let mut input = Vec::new();
let in_part1 = (1u8, 2u8, 3u8);
let in_part2 = (6u8, 5u8, 4u8);

minicbor::encode(in_part1, &mut input).unwrap();
minicbor::encode(in_part2, &mut input).unwrap();

let channel = std::sync::mpsc::channel();
channel.0.send(input).unwrap();

let mut buf = ChannelBuffer::new(channel);

let out_part1 = buf.recv_full_msg::<(u8, u8, u8)>().unwrap();
let out_part2 = buf.recv_full_msg::<(u8, u8, u8)>().unwrap();

assert_eq!(in_part1, out_part1);
assert_eq!(in_part2, out_part2);
}

#[test]
fn fragmented_message_in_multiple_payloads() {
let mut input = Vec::new();
let msg = (11u8, 12u8, 13u8, 14u8, 15u8, 16u8, 17u8);
minicbor::encode(msg, &mut input).unwrap();

let channel = std::sync::mpsc::channel();

while !input.is_empty() {
let chunk = Vec::from(input.drain(0..2).as_slice());
channel.0.send(chunk).unwrap();
}

let mut buf = ChannelBuffer::new(channel);

let out_msg = buf.recv_full_msg::<(u8, u8, u8, u8, u8, u8, u8)>().unwrap();

assert_eq!(msg, out_msg);
}