From 3941cdc620e19cc08ba588276b5f1809324c6126 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Tue, 25 Oct 2022 16:04:22 +1100 Subject: [PATCH] Add test for poll-based API To reduce the duplication we split out a test harness. --- tests/concurrent.rs | 70 ++-------------- tests/harness.rs | 127 ++++++++++++++++++++++++++++ tests/poll_api.rs | 198 ++++++++++++++++++++++++++++++++++++++++++++ tests/tests.rs | 137 ++---------------------------- 4 files changed, 342 insertions(+), 190 deletions(-) create mode 100644 tests/harness.rs create mode 100644 tests/poll_api.rs diff --git a/tests/concurrent.rs b/tests/concurrent.rs index 23ea2d6e..4122441d 100644 --- a/tests/concurrent.rs +++ b/tests/concurrent.rs @@ -8,20 +8,21 @@ // at https://www.apache.org/licenses/LICENSE-2.0 and a copy of the MIT license // at https://opensource.org/licenses/MIT. +#[allow(dead_code)] +mod harness; + use futures::prelude::*; use futures::stream::FuturesUnordered; +use harness::*; use quickcheck::{Arbitrary, Gen, QuickCheck}; use std::{ io, net::{Ipv4Addr, SocketAddr, SocketAddrV4}, - sync::Arc, }; use tokio::net::{TcpListener, TcpStream}; use tokio::{net::TcpSocket, runtime::Runtime, task}; use tokio_util::compat::{Compat, TokioAsyncReadCompatExt}; -use yamux::{ - Config, Connection, ConnectionError, Control, ControlledConnection, Mode, WindowUpdateMode, -}; +use yamux::{Config, Connection, ConnectionError, Control, Mode, WindowUpdateMode}; const PAYLOAD_SIZE: usize = 128 * 1024; @@ -30,7 +31,7 @@ fn concurrent_streams() { let _ = env_logger::try_init(); fn prop(tcp_buffer_sizes: Option) { - let data = Arc::new(vec![0x42; PAYLOAD_SIZE]); + let data = Msg(vec![0x42; PAYLOAD_SIZE]); let n_streams = 1000; Runtime::new().expect("new runtime").block_on(async move { @@ -47,10 +48,11 @@ fn concurrent_streams() { let mut ctrl = ctrl.clone(); task::spawn(async move { - let stream = ctrl.open_stream().await?; + let mut stream = ctrl.open_stream().await?; log::debug!("C: opened new stream {}", stream.id()); - send_recv_data(stream, &data).await?; + send_recv_message(&mut stream, data).await?; + stream.close().await?; Ok::<(), ConnectionError>(()) }) @@ -71,60 +73,6 @@ fn concurrent_streams() { QuickCheck::new().tests(3).quickcheck(prop as fn(_) -> _) } -/// For each incoming stream of `c` echo back to the sender. -async fn echo_server(mut c: Connection) -> Result<(), ConnectionError> -where - T: AsyncRead + AsyncWrite + Unpin, -{ - stream::poll_fn(|cx| c.poll_next_inbound(cx)) - .try_for_each_concurrent(None, |mut stream| async move { - log::debug!("S: accepted new stream"); - - let mut len = [0; 4]; - stream.read_exact(&mut len).await?; - - let mut buf = vec![0; u32::from_be_bytes(len) as usize]; - - stream.read_exact(&mut buf).await?; - stream.write_all(&buf).await?; - stream.close().await?; - - Ok(()) - }) - .await -} - -/// For each incoming stream, do nothing. -async fn noop_server(c: ControlledConnection) -where - T: AsyncRead + AsyncWrite + Unpin + Send + 'static, -{ - c.for_each(|maybe_stream| { - drop(maybe_stream); - future::ready(()) - }) - .await; -} - -/// Sends the given data on the provided stream, length-prefixed. -async fn send_recv_data(mut stream: yamux::Stream, data: &[u8]) -> io::Result<()> { - let len = (data.len() as u32).to_be_bytes(); - stream.write_all(&len).await?; - stream.write_all(data).await?; - stream.close().await?; - - log::debug!("C: {}: wrote {} bytes", stream.id(), data.len()); - - let mut received = vec![0; data.len()]; - stream.read_exact(&mut received).await?; - - log::debug!("C: {}: read {} bytes", stream.id(), received.len()); - - assert_eq!(data, &received[..]); - - Ok(()) -} - /// Send and receive buffer size for a TCP socket. #[derive(Clone, Debug, Copy)] struct TcpBufferSizes { diff --git a/tests/harness.rs b/tests/harness.rs new file mode 100644 index 00000000..9a0bdf2a --- /dev/null +++ b/tests/harness.rs @@ -0,0 +1,127 @@ +use futures::{ + future, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, StreamExt, TryStreamExt, +}; +use futures::{stream, Stream}; +use quickcheck::{Arbitrary, Gen}; +use std::io; +use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; +use tokio::net::{TcpListener, TcpStream}; +use tokio_util::compat::{Compat, TokioAsyncReadCompatExt}; +use yamux::ConnectionError; +use yamux::{Config, WindowUpdateMode}; +use yamux::{Connection, Mode}; + +pub async fn connected_peers( + server_config: Config, + client_config: Config, +) -> io::Result<(Connection>, Connection>)> { + let (listener, addr) = bind().await?; + + let server = async { + let (stream, _) = listener.accept().await?; + Ok(Connection::new( + stream.compat(), + server_config, + Mode::Server, + )) + }; + let client = async { + let stream = TcpStream::connect(addr).await?; + Ok(Connection::new( + stream.compat(), + client_config, + Mode::Client, + )) + }; + + futures::future::try_join(server, client).await +} + +pub async fn bind() -> io::Result<(TcpListener, SocketAddr)> { + let i = Ipv4Addr::new(127, 0, 0, 1); + let s = SocketAddr::V4(SocketAddrV4::new(i, 0)); + let l = TcpListener::bind(&s).await?; + let a = l.local_addr()?; + Ok((l, a)) +} + +/// For each incoming stream of `c` echo back to the sender. +pub async fn echo_server(mut c: Connection) -> Result<(), ConnectionError> +where + T: AsyncRead + AsyncWrite + Unpin, +{ + stream::poll_fn(|cx| c.poll_next_inbound(cx)) + .try_for_each_concurrent(None, |mut stream| async move { + { + let (mut r, mut w) = AsyncReadExt::split(&mut stream); + futures::io::copy(&mut r, &mut w).await?; + } + stream.close().await?; + Ok(()) + }) + .await +} + +/// For each incoming stream, do nothing. +pub async fn noop_server(c: impl Stream>) { + c.for_each(|maybe_stream| { + drop(maybe_stream); + future::ready(()) + }) + .await; +} + +pub async fn send_recv_message(stream: &mut yamux::Stream, Msg(msg): Msg) -> io::Result<()> { + let id = stream.id(); + let (mut reader, mut writer) = AsyncReadExt::split(stream); + + let len = msg.len(); + let write_fut = async { + writer.write_all(&msg).await.unwrap(); + log::debug!("C: {}: sent {} bytes", id, len); + }; + let mut data = vec![0; msg.len()]; + let read_fut = async { + reader.read_exact(&mut data).await.unwrap(); + log::debug!("C: {}: received {} bytes", id, data.len()); + }; + futures::future::join(write_fut, read_fut).await; + assert_eq!(data, msg); + + Ok(()) +} + +#[derive(Clone, Debug)] +pub struct Msg(pub Vec); + +impl Arbitrary for Msg { + fn arbitrary(g: &mut Gen) -> Msg { + let mut msg = Msg(Arbitrary::arbitrary(g)); + if msg.0.is_empty() { + msg.0.push(Arbitrary::arbitrary(g)); + } + + msg + } + + fn shrink(&self) -> Box> { + Box::new(self.0.shrink().filter(|v| !v.is_empty()).map(Msg)) + } +} + +#[derive(Clone, Debug)] +pub struct TestConfig(pub Config); + +impl Arbitrary for TestConfig { + fn arbitrary(g: &mut Gen) -> Self { + let mut c = Config::default(); + c.set_window_update_mode(if bool::arbitrary(g) { + WindowUpdateMode::OnRead + } else { + WindowUpdateMode::OnReceive + }); + c.set_read_after_close(Arbitrary::arbitrary(g)); + c.set_receive_window(256 * 1024 + u32::arbitrary(g) % (768 * 1024)); + TestConfig(c) + } +} diff --git a/tests/poll_api.rs b/tests/poll_api.rs new file mode 100644 index 00000000..d9466406 --- /dev/null +++ b/tests/poll_api.rs @@ -0,0 +1,198 @@ +#[allow(dead_code)] +mod harness; + +use futures::future::BoxFuture; +use futures::stream::FuturesUnordered; +use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, FutureExt, StreamExt}; +use harness::*; +use quickcheck::QuickCheck; +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::net::TcpStream; +use tokio::runtime::Runtime; +use tokio_util::compat::TokioAsyncReadCompatExt; +use yamux::{Connection, Mode}; + +#[test] +fn prop_config_send_recv_multi() { + let _ = env_logger::try_init(); + + fn prop(msgs: Vec, cfg1: TestConfig, cfg2: TestConfig) { + Runtime::new().unwrap().block_on(async move { + let num_messagses = msgs.len(); + + let (listener, address) = bind().await.expect("bind"); + + let server = async { + let socket = listener.accept().await.expect("accept").0.compat(); + let connection = Connection::new(socket, cfg1.0, Mode::Server); + + EchoServer::new(connection).await + }; + + let client = async { + let socket = TcpStream::connect(address).await.expect("connect").compat(); + let connection = Connection::new(socket, cfg2.0, Mode::Client); + + MessageSender::new(connection, msgs).await + }; + + let (server_processed, client_processed) = + futures::future::try_join(server, client).await.unwrap(); + + assert_eq!(server_processed, num_messagses); + assert_eq!(client_processed, num_messagses); + }) + } + + QuickCheck::new().quickcheck(prop as fn(_, _, _) -> _) +} + +struct EchoServer { + connection: Connection, + worker_streams: FuturesUnordered>>, + streams_processed: usize, + connection_closed: bool, +} + +impl EchoServer { + fn new(connection: Connection) -> Self { + Self { + connection, + worker_streams: FuturesUnordered::default(), + streams_processed: 0, + connection_closed: false, + } + } +} + +impl Future for EchoServer +where + T: AsyncRead + AsyncWrite + Unpin, +{ + type Output = yamux::Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.get_mut(); + + loop { + match this.worker_streams.poll_next_unpin(cx) { + Poll::Ready(Some(Ok(()))) => { + this.streams_processed += 1; + continue; + } + Poll::Ready(Some(Err(e))) => { + eprintln!("A stream failed: {}", e); + continue; + } + Poll::Ready(None) => { + if this.connection_closed { + return Poll::Ready(Ok(this.streams_processed)); + } + } + Poll::Pending => {} + } + + match this.connection.poll_next_inbound(cx) { + Poll::Ready(Some(Ok(mut stream))) => { + this.worker_streams.push( + async move { + { + let (mut r, mut w) = AsyncReadExt::split(&mut stream); + futures::io::copy(&mut r, &mut w).await?; + } + stream.close().await?; + Ok(()) + } + .boxed(), + ); + continue; + } + Poll::Ready(None) | Poll::Ready(Some(Err(_))) => { + this.connection_closed = true; + continue; + } + Poll::Pending => {} + } + + return Poll::Pending; + } + } +} + +struct MessageSender { + connection: Connection, + pending_messages: Vec, + worker_streams: FuturesUnordered>, + streams_processed: usize, +} + +impl MessageSender { + fn new(connection: Connection, messages: Vec) -> Self { + Self { + connection, + pending_messages: messages, + worker_streams: FuturesUnordered::default(), + streams_processed: 0, + } + } +} + +impl Future for MessageSender +where + T: AsyncRead + AsyncWrite + Unpin, +{ + type Output = yamux::Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.get_mut(); + + loop { + if this.pending_messages.is_empty() && this.worker_streams.is_empty() { + futures::ready!(this.connection.poll_close(cx)?); + + return Poll::Ready(Ok(this.streams_processed)); + } + + if let Some(message) = this.pending_messages.pop() { + match this.connection.poll_new_outbound(cx)? { + Poll::Ready(mut stream) => { + this.worker_streams.push( + async move { + send_recv_message(&mut stream, message).await.unwrap(); + stream.close().await.unwrap(); + } + .boxed(), + ); + continue; + } + Poll::Pending => { + this.pending_messages.push(message); + } + } + } + + match this.worker_streams.poll_next_unpin(cx) { + Poll::Ready(Some(())) => { + this.streams_processed += 1; + continue; + } + Poll::Ready(None) | Poll::Pending => {} + } + + match this.connection.poll_next_inbound(cx)? { + Poll::Ready(Some(stream)) => { + drop(stream); + panic!("Did not expect remote to open a stream"); + } + Poll::Ready(None) => { + panic!("Did not expect remote to close the connection"); + } + Poll::Pending => {} + } + + return Poll::Pending; + } + } +} diff --git a/tests/tests.rs b/tests/tests.rs index 59aab23e..bf00fdb5 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -8,29 +8,23 @@ // at https://www.apache.org/licenses/LICENSE-2.0 and a copy of the MIT license // at https://opensource.org/licenses/MIT. +#[allow(dead_code)] +mod harness; + use futures::channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender}; use futures::executor::LocalPool; use futures::future::join; use futures::io::AsyncReadExt; +use futures::prelude::*; use futures::task::{Spawn, SpawnExt}; -use futures::{future, prelude::*}; -use quickcheck::{Arbitrary, Gen, QuickCheck, TestResult}; +use harness::*; +use harness::{Msg, TestConfig}; +use quickcheck::{QuickCheck, TestResult}; use std::panic::panic_any; use std::pin::Pin; use std::sync::{Arc, Mutex}; use std::task::{Context, Poll, Waker}; -use std::{ - fmt::Debug, - io, - net::{Ipv4Addr, SocketAddr, SocketAddrV4}, -}; -use tokio::{ - net::{TcpListener, TcpStream}, - runtime::Runtime, - task, -}; -use tokio_util::compat::{Compat, TokioAsyncReadCompatExt}; -use yamux::WindowUpdateMode; +use tokio::{runtime::Runtime, task}; use yamux::{Config, Connection, ConnectionError, Control, Mode}; #[test] @@ -277,101 +271,6 @@ fn write_deadlock() { ); } -#[derive(Clone, Debug)] -struct Msg(Vec); - -impl Arbitrary for Msg { - fn arbitrary(g: &mut Gen) -> Msg { - let mut msg = Msg(Arbitrary::arbitrary(g)); - if msg.0.is_empty() { - msg.0.push(Arbitrary::arbitrary(g)); - } - - msg - } - - fn shrink(&self) -> Box> { - Box::new(self.0.shrink().filter(|v| !v.is_empty()).map(Msg)) - } -} - -#[derive(Clone, Debug)] -struct TestConfig(Config); - -impl Arbitrary for TestConfig { - fn arbitrary(g: &mut Gen) -> Self { - let mut c = Config::default(); - c.set_window_update_mode(if bool::arbitrary(g) { - WindowUpdateMode::OnRead - } else { - WindowUpdateMode::OnReceive - }); - c.set_read_after_close(Arbitrary::arbitrary(g)); - c.set_receive_window(256 * 1024 + u32::arbitrary(g) % (768 * 1024)); - TestConfig(c) - } -} - -async fn connected_peers( - server_config: Config, - client_config: Config, -) -> io::Result<(Connection>, Connection>)> { - let (listener, addr) = bind().await?; - - let server = async { - let (stream, _) = listener.accept().await?; - Ok(Connection::new( - stream.compat(), - server_config, - Mode::Server, - )) - }; - let client = async { - let stream = TcpStream::connect(addr).await?; - Ok(Connection::new( - stream.compat(), - client_config, - Mode::Client, - )) - }; - - futures::future::try_join(server, client).await -} - -async fn bind() -> io::Result<(TcpListener, SocketAddr)> { - let i = Ipv4Addr::new(127, 0, 0, 1); - let s = SocketAddr::V4(SocketAddrV4::new(i, 0)); - let l = TcpListener::bind(&s).await?; - let a = l.local_addr()?; - Ok((l, a)) -} - -/// For each incoming stream of `c` echo back to the sender. -async fn echo_server(mut c: Connection) -> Result<(), ConnectionError> -where - T: AsyncRead + AsyncWrite + Unpin, -{ - stream::poll_fn(|cx| c.poll_next_inbound(cx)) - .try_for_each_concurrent(None, |mut stream| async move { - { - let (mut r, mut w) = AsyncReadExt::split(&mut stream); - futures::io::copy(&mut r, &mut w).await?; - } - stream.close().await?; - Ok(()) - }) - .await -} - -/// For each incoming stream, do nothing. -async fn noop_server(c: impl Stream>) { - c.for_each(|maybe_stream| { - drop(maybe_stream); - future::ready(()) - }) - .await; -} - /// Send all messages, opening a new stream for each one. async fn send_on_separate_streams( mut control: Control, @@ -411,26 +310,6 @@ async fn send_on_single_stream( Ok(()) } -async fn send_recv_message(stream: &mut yamux::Stream, Msg(msg): Msg) -> io::Result<()> { - let id = stream.id(); - let (mut reader, mut writer) = AsyncReadExt::split(stream); - - let len = msg.len(); - let write_fut = async { - writer.write_all(&msg).await.unwrap(); - log::debug!("C: {}: sent {} bytes", id, len); - }; - let mut data = vec![0; msg.len()]; - let read_fut = async { - reader.read_exact(&mut data).await.unwrap(); - log::debug!("C: {}: received {} bytes", id, data.len()); - }; - futures::future::join(write_fut, read_fut).await; - assert_eq!(data, msg); - - Ok(()) -} - /// This module implements a duplex connection via channels with bounded /// capacities. The channels used for the implementation are unbounded /// as the operate at the granularity of variably-sized chunks of bytes