Skip to content

Commit

Permalink
Add stream_select macro (#2262)
Browse files Browse the repository at this point in the history
  • Loading branch information
Xaeroxe authored Jul 31, 2021
1 parent ea07b4b commit 633a905
Show file tree
Hide file tree
Showing 6 changed files with 219 additions and 1 deletion.
10 changes: 10 additions & 0 deletions futures-macro/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use proc_macro::TokenStream;
mod executor;
mod join;
mod select;
mod stream_select;

/// The `join!` macro.
#[cfg_attr(fn_like_proc_macro, proc_macro)]
Expand Down Expand Up @@ -54,3 +55,12 @@ pub fn select_biased_internal(input: TokenStream) -> TokenStream {
pub fn test_internal(input: TokenStream, item: TokenStream) -> TokenStream {
crate::executor::test(input, item)
}

/// The `stream_select!` macro.
#[cfg_attr(fn_like_proc_macro, proc_macro)]
#[cfg_attr(not(fn_like_proc_macro), proc_macro_hack::proc_macro_hack)]
pub fn stream_select_internal(input: TokenStream) -> TokenStream {
crate::stream_select::stream_select(input.into())
.unwrap_or_else(syn::Error::into_compile_error)
.into()
}
113 changes: 113 additions & 0 deletions futures-macro/src/stream_select.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
use proc_macro2::TokenStream;
use quote::{format_ident, quote, ToTokens};
use syn::{parse::Parser, punctuated::Punctuated, Expr, Index, Token};

/// The `stream_select!` macro.
pub(crate) fn stream_select(input: TokenStream) -> Result<TokenStream, syn::Error> {
let args = Punctuated::<Expr, Token![,]>::parse_terminated.parse2(input)?;
if args.len() < 2 {
return Ok(quote! {
compile_error!("stream select macro needs at least two arguments.")
});
}
let generic_idents = (0..args.len()).map(|i| format_ident!("_{}", i)).collect::<Vec<_>>();
let field_idents = (0..args.len()).map(|i| format_ident!("__{}", i)).collect::<Vec<_>>();
let field_idents_2 = (0..args.len()).map(|i| format_ident!("___{}", i)).collect::<Vec<_>>();
let field_indices = (0..args.len()).map(Index::from).collect::<Vec<_>>();
let args = args.iter().map(|e| e.to_token_stream());

Ok(quote! {
{
#[derive(Debug)]
struct StreamSelect<#(#generic_idents),*> (#(Option<#generic_idents>),*);

enum StreamEnum<#(#generic_idents),*> {
#(
#generic_idents(#generic_idents)
),*,
None,
}

impl<ITEM, #(#generic_idents),*> __futures_crate::stream::Stream for StreamEnum<#(#generic_idents),*>
where #(#generic_idents: __futures_crate::stream::Stream<Item=ITEM> + ::std::marker::Unpin,)*
{
type Item = ITEM;

fn poll_next(mut self: ::std::pin::Pin<&mut Self>, cx: &mut __futures_crate::task::Context<'_>) -> __futures_crate::task::Poll<Option<Self::Item>> {
match self.get_mut() {
#(
Self::#generic_idents(#generic_idents) => ::std::pin::Pin::new(#generic_idents).poll_next(cx)
),*,
Self::None => panic!("StreamEnum::None should never be polled!"),
}
}
}

impl<ITEM, #(#generic_idents),*> __futures_crate::stream::Stream for StreamSelect<#(#generic_idents),*>
where #(#generic_idents: __futures_crate::stream::Stream<Item=ITEM> + ::std::marker::Unpin,)*
{
type Item = ITEM;

fn poll_next(mut self: ::std::pin::Pin<&mut Self>, cx: &mut __futures_crate::task::Context<'_>) -> __futures_crate::task::Poll<Option<Self::Item>> {
let Self(#(ref mut #field_idents),*) = self.get_mut();
#(
let mut #field_idents_2 = false;
)*
let mut any_pending = false;
{
let mut stream_array = [#(#field_idents.as_mut().map(|f| StreamEnum::#generic_idents(f)).unwrap_or(StreamEnum::None)),*];
__futures_crate::async_await::shuffle(&mut stream_array);

for mut s in stream_array {
if let StreamEnum::None = s {
continue;
} else {
match __futures_crate::stream::Stream::poll_next(::std::pin::Pin::new(&mut s), cx) {
r @ __futures_crate::task::Poll::Ready(Some(_)) => {
return r;
},
__futures_crate::task::Poll::Pending => {
any_pending = true;
},
__futures_crate::task::Poll::Ready(None) => {
match s {
#(
StreamEnum::#generic_idents(_) => { #field_idents_2 = true; }
),*,
StreamEnum::None => panic!("StreamEnum::None should never be polled!"),
}
},
}
}
}
}
#(
if #field_idents_2 {
*#field_idents = None;
}
)*
if any_pending {
__futures_crate::task::Poll::Pending
} else {
__futures_crate::task::Poll::Ready(None)
}
}

fn size_hint(&self) -> (usize, Option<usize>) {
let mut s = (0, Some(0));
#(
if let Some(new_hint) = self.#field_indices.as_ref().map(|s| s.size_hint()) {
s.0 += new_hint.0;
// We can change this out for `.zip` when the MSRV is 1.46.0 or higher.
s.1 = s.1.and_then(|a| new_hint.1.map(|b| a + b));
}
)*
s
}
}

StreamSelect(#(Some(#args)),*)

}
})
}
7 changes: 7 additions & 0 deletions futures-util/src/async_await/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,13 @@ mod select_mod;
#[cfg(feature = "async-await-macro")]
pub use self::select_mod::*;

// Primary export is a macro
#[cfg(feature = "async-await-macro")]
mod stream_select_mod;
#[allow(unreachable_pub)] // https://github.com/rust-lang/rust/issues/64762
#[cfg(feature = "async-await-macro")]
pub use self::stream_select_mod::*;

#[cfg(feature = "std")]
#[cfg(feature = "async-await-macro")]
mod random;
Expand Down
45 changes: 45 additions & 0 deletions futures-util/src/async_await/stream_select_mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
//! The `stream_select` macro.

#[cfg(feature = "std")]
#[allow(unreachable_pub)]
#[doc(hidden)]
#[cfg_attr(not(fn_like_proc_macro), proc_macro_hack::proc_macro_hack(support_nested))]
pub use futures_macro::stream_select_internal;

/// Combines several streams, all producing the same `Item` type, into one stream.
/// This is similar to `select_all` but does not require the streams to all be the same type.
/// It also keeps the streams inline, and does not require `Box<dyn Stream>`s to be allocated.
/// Streams passed to this macro must be `Unpin`.
///
/// If multiple streams are ready, one will be pseudo randomly selected at runtime.
///
/// This macro is gated behind the `async-await` feature of this library, which is activated by default.
/// Note that `stream_select!` relies on `proc-macro-hack`, and may require to set the compiler's recursion
/// limit very high, e.g. `#![recursion_limit="1024"]`.
///
/// # Examples
///
/// ```
/// # futures::executor::block_on(async {
/// use futures::{stream, StreamExt, stream_select};
/// let endless_ints = |i| stream::iter(vec![i].into_iter().cycle()).fuse();
///
/// let mut endless_numbers = stream_select!(endless_ints(1i32), endless_ints(2), endless_ints(3));
/// match endless_numbers.next().await {
/// Some(1) => println!("Got a 1"),
/// Some(2) => println!("Got a 2"),
/// Some(3) => println!("Got a 3"),
/// _ => unreachable!(),
/// }
/// # });
/// ```
#[cfg(feature = "std")]
#[macro_export]
macro_rules! stream_select {
($($tokens:tt)*) => {{
use $crate::__private as __futures_crate;
$crate::stream_select_internal! {
$( $tokens )*
}
}}
}
5 changes: 5 additions & 0 deletions futures/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,11 @@ pub use futures_util::{join, pending, poll, select_biased, try_join}; // Async-a
#[doc(inline)]
pub use futures_util::{future, sink, stream, task};

#[cfg(feature = "std")]
#[cfg(feature = "async-await")]
pub use futures_util::stream_select;

#[cfg_attr(feature = "cfg-target-has-atomic", cfg(target_has_atomic = "ptr"))]
#[cfg(feature = "alloc")]
#[doc(inline)]
pub use futures_channel as channel;
Expand Down
40 changes: 39 additions & 1 deletion futures/tests/async_await_macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ use futures::future::{self, poll_fn, FutureExt};
use futures::sink::SinkExt;
use futures::stream::StreamExt;
use futures::task::{Context, Poll};
use futures::{join, pending, pin_mut, poll, select, select_biased, try_join};
use futures::{
join, pending, pin_mut, poll, select, select_biased, stream, stream_select, try_join,
};
use std::mem;

#[test]
Expand Down Expand Up @@ -308,6 +310,42 @@ fn select_on_mutable_borrowing_future_with_same_borrow_in_block_and_default() {
});
}

#[test]
#[allow(unused_assignments)]
fn stream_select() {
// stream_select! macro
block_on(async {
let endless_ints = |i| stream::iter(vec![i].into_iter().cycle());

let mut endless_ones = stream_select!(endless_ints(1i32), stream::pending());
assert_eq!(endless_ones.next().await, Some(1));
assert_eq!(endless_ones.next().await, Some(1));

let mut finite_list =
stream_select!(stream::iter(vec![1].into_iter()), stream::iter(vec![1].into_iter()));
assert_eq!(finite_list.next().await, Some(1));
assert_eq!(finite_list.next().await, Some(1));
assert_eq!(finite_list.next().await, None);

let endless_mixed = stream_select!(endless_ints(1i32), endless_ints(2), endless_ints(3));
// Take 1000, and assert a somewhat even distribution of values.
// The fairness is randomized, but over 1000 samples we should be pretty close to even.
// This test may be a bit flaky. Feel free to adjust the margins as you see fit.
let mut count = 0;
let results = endless_mixed
.take_while(move |_| {
count += 1;
let ret = count < 1000;
async move { ret }
})
.collect::<Vec<_>>()
.await;
assert!(results.iter().filter(|x| **x == 1).count() >= 299);
assert!(results.iter().filter(|x| **x == 2).count() >= 299);
assert!(results.iter().filter(|x| **x == 3).count() >= 299);
});
}

#[test]
fn join_size() {
let fut = async {
Expand Down

0 comments on commit 633a905

Please sign in to comment.