Skip to content

Commit

Permalink
Add test for poll-based API
Browse files Browse the repository at this point in the history
  • Loading branch information
thomaseizinger committed Oct 25, 2022
1 parent 077dcf2 commit 272ddf9
Showing 1 changed file with 236 additions and 0 deletions.
236 changes: 236 additions & 0 deletions tests/poll_api.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@
use std::future::Future;
use std::io;
use std::net::{Ipv4Addr, SocketAddrV4, SocketAddr};
use std::pin::Pin;
use std::task::{Context, Poll};
use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, FutureExt, StreamExt};
use futures::future::BoxFuture;
use futures::stream::FuturesUnordered;
use quickcheck::{Arbitrary, Gen, QuickCheck};
use tokio::net::{TcpListener, TcpStream};
use tokio::runtime::Runtime;
use tokio_util::compat::TokioAsyncReadCompatExt;
use yamux::{Connection, Mode, WindowUpdateMode};

#[test]
fn prop_config_send_recv_multi() {
let _ = env_logger::try_init();

fn prop(msgs: Vec<Msg>, cfg1: TestConfig, cfg2: TestConfig) {
Runtime::new().unwrap().block_on(async move {
let num_messagses = msgs.len();

let (listener, address) = bind().await.expect("bind");

let server = async {
let socket = listener.accept().await.expect("accept").0.compat();
let connection = Connection::new(socket, cfg1.0, Mode::Server);

EchoServer::new(connection).await
};

let client = async {
let socket = TcpStream::connect(address).await.expect("connect").compat();
let connection = Connection::new(socket, cfg2.0, Mode::Client);

MessageSender::new(connection, msgs).await
};

let (server_processed, client_processed) = futures::future::try_join(server, client).await.unwrap();

assert_eq!(server_processed, num_messagses);
assert_eq!(client_processed, num_messagses);
})
}

QuickCheck::new()
.tests(10)
.quickcheck(prop as fn(_, _, _) -> _)
}

#[derive(Clone, Debug)]
struct Msg(Vec<u8>);

impl Arbitrary for Msg {
fn arbitrary(g: &mut Gen) -> Msg {
let mut msg = Msg(Arbitrary::arbitrary(g));
if msg.0.is_empty() {
msg.0.push(Arbitrary::arbitrary(g));
}

msg
}

fn shrink(&self) -> Box<dyn Iterator<Item=Self>> {
Box::new(self.0.shrink().filter(|v| !v.is_empty()).map(|v| Msg(v)))
}
}

#[derive(Clone, Debug)]
struct TestConfig(yamux::Config);

impl Arbitrary for TestConfig {
fn arbitrary(g: &mut Gen) -> Self {
let mut c = yamux::Config::default();
c.set_window_update_mode(if bool::arbitrary(g) {
WindowUpdateMode::OnRead
} else {
WindowUpdateMode::OnReceive
});
c.set_read_after_close(Arbitrary::arbitrary(g));
c.set_receive_window(256 * 1024 + u32::arbitrary(g) % (768 * 1024));
TestConfig(c)
}
}

async fn bind() -> io::Result<(TcpListener, SocketAddr)> {
let i = Ipv4Addr::new(127, 0, 0, 1);
let s = SocketAddr::V4(SocketAddrV4::new(i, 0));
let l = TcpListener::bind(&s).await?;
let a = l.local_addr()?;
Ok((l, a))
}

struct EchoServer<T> {
connection: Connection<T>,
worker_streams: FuturesUnordered<BoxFuture<'static, yamux::Result<()>>>,
streams_processed: usize
}

impl<T> EchoServer<T> {
fn new(connection: Connection<T>) -> Self {
Self {
connection,
worker_streams: FuturesUnordered::default(),
streams_processed: 0
}
}
}

impl<T> Future for EchoServer<T> where T: AsyncRead + AsyncWrite + Unpin {
type Output = yamux::Result<usize>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut();

loop {
match this.worker_streams.poll_next_unpin(cx) {
Poll::Ready(Some(Ok(()))) => {
this.streams_processed += 1;
continue;
}
Poll::Ready(Some(Err(e))) => {
eprintln!("A stream failed: {}", e);
continue;
}
Poll::Ready(None) | Poll::Pending => {}
}

match this.connection.poll_next_inbound(cx)? {
Poll::Ready(Some(mut stream)) => {
this.worker_streams.push(async move {
{
let (mut r, mut w) = AsyncReadExt::split(&mut stream);
futures::io::copy(&mut r, &mut w).await?;
}
stream.close().await?;
Ok(())
}.boxed());
continue;
}
Poll::Ready(None) => return Poll::Ready(Ok(this.streams_processed)),
Poll::Pending => {}
}

return Poll::Pending;
}
}
}

struct MessageSender<T> {
connection: Connection<T>,
pending_messages: Vec<Msg>,
worker_streams: FuturesUnordered<BoxFuture<'static, ()>>,
streams_processed: usize
}

impl<T> MessageSender<T> {
fn new(connection: Connection<T>, messages: Vec<Msg>) -> Self {
Self {
connection,
pending_messages: messages,
worker_streams: FuturesUnordered::default(),
streams_processed: 0
}
}
}

impl<T> Future for MessageSender<T> where T: AsyncRead + AsyncWrite + Unpin {
type Output = yamux::Result<usize>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut();

loop {
if this.pending_messages.is_empty() && this.worker_streams.is_empty() {
futures::ready!(this.connection.poll_close(cx)?);

return Poll::Ready(Ok(this.streams_processed));
}

match this.worker_streams.poll_next_unpin(cx) {
Poll::Ready(Some(())) => {
this.streams_processed += 1;
continue;
}
Poll::Ready(None) | Poll::Pending => {}
}

if let Some(Msg(message)) = this.pending_messages.pop() {
match this.connection.poll_new_outbound(cx)? {
Poll::Ready(stream) => {
this.worker_streams.push(async move {
let id = stream.id();
let len = message.len();

let (mut reader, mut writer) = AsyncReadExt::split(stream);

let write_fut = async {
writer.write_all(&message).await.unwrap();
log::debug!("C: {}: sent {} bytes", id, len);
writer.close().await.unwrap();
};

let mut received = Vec::new();
let read_fut = async {
reader.read_to_end(&mut received).await.unwrap();
log::debug!("C: {}: received {} bytes", id, received.len());
};

futures::future::join(write_fut, read_fut).await;

assert_eq!(message, received)
}.boxed());
continue;
}
Poll::Pending => {
this.pending_messages.push(Msg(message));
}
}
}

match this.connection.poll_next_inbound(cx)? {
Poll::Ready(Some(stream)) => {
drop(stream);
panic!("Did not expect remote to open a stream");
}
Poll::Ready(None) => {
panic!("Did not expect remote to close the connection");
},
Poll::Pending => {}
}

return Poll::Pending;
}
}
}

0 comments on commit 272ddf9

Please sign in to comment.