Skip to content

Commit

Permalink
Add test for poll-based API
Browse files Browse the repository at this point in the history
To reduce the duplication we split out a test harness.
  • Loading branch information
thomaseizinger committed Nov 3, 2022
1 parent 79c1479 commit 3941cdc
Show file tree
Hide file tree
Showing 4 changed files with 342 additions and 190 deletions.
70 changes: 9 additions & 61 deletions tests/concurrent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,21 @@
// at https://www.apache.org/licenses/LICENSE-2.0 and a copy of the MIT license
// at https://opensource.org/licenses/MIT.

#[allow(dead_code)]
mod harness;

use futures::prelude::*;
use futures::stream::FuturesUnordered;
use harness::*;
use quickcheck::{Arbitrary, Gen, QuickCheck};
use std::{
io,
net::{Ipv4Addr, SocketAddr, SocketAddrV4},
sync::Arc,
};
use tokio::net::{TcpListener, TcpStream};
use tokio::{net::TcpSocket, runtime::Runtime, task};
use tokio_util::compat::{Compat, TokioAsyncReadCompatExt};
use yamux::{
Config, Connection, ConnectionError, Control, ControlledConnection, Mode, WindowUpdateMode,
};
use yamux::{Config, Connection, ConnectionError, Control, Mode, WindowUpdateMode};

const PAYLOAD_SIZE: usize = 128 * 1024;

Expand All @@ -30,7 +31,7 @@ fn concurrent_streams() {
let _ = env_logger::try_init();

fn prop(tcp_buffer_sizes: Option<TcpBufferSizes>) {
let data = Arc::new(vec![0x42; PAYLOAD_SIZE]);
let data = Msg(vec![0x42; PAYLOAD_SIZE]);
let n_streams = 1000;

Runtime::new().expect("new runtime").block_on(async move {
Expand All @@ -47,10 +48,11 @@ fn concurrent_streams() {
let mut ctrl = ctrl.clone();

task::spawn(async move {
let stream = ctrl.open_stream().await?;
let mut stream = ctrl.open_stream().await?;
log::debug!("C: opened new stream {}", stream.id());

send_recv_data(stream, &data).await?;
send_recv_message(&mut stream, data).await?;
stream.close().await?;

Ok::<(), ConnectionError>(())
})
Expand All @@ -71,60 +73,6 @@ fn concurrent_streams() {
QuickCheck::new().tests(3).quickcheck(prop as fn(_) -> _)
}

/// For each incoming stream of `c` echo back to the sender.
async fn echo_server<T>(mut c: Connection<T>) -> Result<(), ConnectionError>
where
T: AsyncRead + AsyncWrite + Unpin,
{
stream::poll_fn(|cx| c.poll_next_inbound(cx))
.try_for_each_concurrent(None, |mut stream| async move {
log::debug!("S: accepted new stream");

let mut len = [0; 4];
stream.read_exact(&mut len).await?;

let mut buf = vec![0; u32::from_be_bytes(len) as usize];

stream.read_exact(&mut buf).await?;
stream.write_all(&buf).await?;
stream.close().await?;

Ok(())
})
.await
}

/// For each incoming stream, do nothing.
async fn noop_server<T>(c: ControlledConnection<T>)
where
T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
c.for_each(|maybe_stream| {
drop(maybe_stream);
future::ready(())
})
.await;
}

/// Sends the given data on the provided stream, length-prefixed.
async fn send_recv_data(mut stream: yamux::Stream, data: &[u8]) -> io::Result<()> {
let len = (data.len() as u32).to_be_bytes();
stream.write_all(&len).await?;
stream.write_all(data).await?;
stream.close().await?;

log::debug!("C: {}: wrote {} bytes", stream.id(), data.len());

let mut received = vec![0; data.len()];
stream.read_exact(&mut received).await?;

log::debug!("C: {}: read {} bytes", stream.id(), received.len());

assert_eq!(data, &received[..]);

Ok(())
}

/// Send and receive buffer size for a TCP socket.
#[derive(Clone, Debug, Copy)]
struct TcpBufferSizes {
Expand Down
127 changes: 127 additions & 0 deletions tests/harness.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
use futures::{
future, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, StreamExt, TryStreamExt,
};
use futures::{stream, Stream};
use quickcheck::{Arbitrary, Gen};
use std::io;
use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
use tokio::net::{TcpListener, TcpStream};
use tokio_util::compat::{Compat, TokioAsyncReadCompatExt};
use yamux::ConnectionError;
use yamux::{Config, WindowUpdateMode};
use yamux::{Connection, Mode};

pub async fn connected_peers(
server_config: Config,
client_config: Config,
) -> io::Result<(Connection<Compat<TcpStream>>, Connection<Compat<TcpStream>>)> {
let (listener, addr) = bind().await?;

let server = async {
let (stream, _) = listener.accept().await?;
Ok(Connection::new(
stream.compat(),
server_config,
Mode::Server,
))
};
let client = async {
let stream = TcpStream::connect(addr).await?;
Ok(Connection::new(
stream.compat(),
client_config,
Mode::Client,
))
};

futures::future::try_join(server, client).await
}

pub async fn bind() -> io::Result<(TcpListener, SocketAddr)> {
let i = Ipv4Addr::new(127, 0, 0, 1);
let s = SocketAddr::V4(SocketAddrV4::new(i, 0));
let l = TcpListener::bind(&s).await?;
let a = l.local_addr()?;
Ok((l, a))
}

/// For each incoming stream of `c` echo back to the sender.
pub async fn echo_server<T>(mut c: Connection<T>) -> Result<(), ConnectionError>
where
T: AsyncRead + AsyncWrite + Unpin,
{
stream::poll_fn(|cx| c.poll_next_inbound(cx))
.try_for_each_concurrent(None, |mut stream| async move {
{
let (mut r, mut w) = AsyncReadExt::split(&mut stream);
futures::io::copy(&mut r, &mut w).await?;
}
stream.close().await?;
Ok(())
})
.await
}

/// For each incoming stream, do nothing.
pub async fn noop_server(c: impl Stream<Item = Result<yamux::Stream, yamux::ConnectionError>>) {
c.for_each(|maybe_stream| {
drop(maybe_stream);
future::ready(())
})
.await;
}

pub async fn send_recv_message(stream: &mut yamux::Stream, Msg(msg): Msg) -> io::Result<()> {
let id = stream.id();
let (mut reader, mut writer) = AsyncReadExt::split(stream);

let len = msg.len();
let write_fut = async {
writer.write_all(&msg).await.unwrap();
log::debug!("C: {}: sent {} bytes", id, len);
};
let mut data = vec![0; msg.len()];
let read_fut = async {
reader.read_exact(&mut data).await.unwrap();
log::debug!("C: {}: received {} bytes", id, data.len());
};
futures::future::join(write_fut, read_fut).await;
assert_eq!(data, msg);

Ok(())
}

#[derive(Clone, Debug)]
pub struct Msg(pub Vec<u8>);

impl Arbitrary for Msg {
fn arbitrary(g: &mut Gen) -> Msg {
let mut msg = Msg(Arbitrary::arbitrary(g));
if msg.0.is_empty() {
msg.0.push(Arbitrary::arbitrary(g));
}

msg
}

fn shrink(&self) -> Box<dyn Iterator<Item = Self>> {
Box::new(self.0.shrink().filter(|v| !v.is_empty()).map(Msg))
}
}

#[derive(Clone, Debug)]
pub struct TestConfig(pub Config);

impl Arbitrary for TestConfig {
fn arbitrary(g: &mut Gen) -> Self {
let mut c = Config::default();
c.set_window_update_mode(if bool::arbitrary(g) {
WindowUpdateMode::OnRead
} else {
WindowUpdateMode::OnReceive
});
c.set_read_after_close(Arbitrary::arbitrary(g));
c.set_receive_window(256 * 1024 + u32::arbitrary(g) % (768 * 1024));
TestConfig(c)
}
}
Loading

0 comments on commit 3941cdc

Please sign in to comment.