Skip to content

Commit 38e9e69

Browse files
committed
refactor: try out a version without a macro
1 parent 92f15fb commit 38e9e69

File tree

3 files changed

+136
-78
lines changed

3 files changed

+136
-78
lines changed

examples/stream.rs

Lines changed: 18 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,48 +1,30 @@
11
use std::{
22
collections::BTreeMap,
3-
future::{Future, IntoFuture},
43
net::{Ipv4Addr, SocketAddr, SocketAddrV4},
54
};
65

76
use anyhow::{Context, Result};
8-
use futures_util::future::BoxFuture;
97
use irpc::{
10-
channel::{mpsc, oneshot},
8+
channel::oneshot,
119
rpc::RemoteService,
1210
rpc_requests,
13-
util::{
14-
make_client_endpoint, make_server_endpoint, IrpcReceiverFutExt, MpscSenderExt, StreamItem,
15-
},
11+
util::{make_client_endpoint, make_server_endpoint, MpscSenderExt, Progress, StreamSender},
1612
Client, WithChannels,
1713
};
1814
// Import the macro
1915
use n0_future::{
2016
task::{self, AbortOnDropHandle},
21-
Stream, StreamExt,
17+
StreamExt,
2218
};
2319
use serde::{Deserialize, Serialize};
2420
use tracing::info;
2521

26-
#[derive(Debug, Serialize, Deserialize)]
22+
#[derive(Debug, Serialize, Deserialize, thiserror::Error)]
23+
#[error("{message}")]
2724
struct Error {
2825
message: String,
2926
}
3027

31-
impl std::error::Error for Error {}
32-
33-
impl std::fmt::Display for Error {
34-
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35-
write!(f, "{}", self.message)
36-
}
37-
}
38-
39-
#[derive(Debug, Serialize, Deserialize, StreamItem)]
40-
enum GetItem {
41-
Item(String),
42-
Error(Error),
43-
Done,
44-
}
45-
4628
#[derive(Debug, Serialize, Deserialize)]
4729
struct Set {
4830
key: String,
@@ -61,7 +43,7 @@ struct Get {
6143
enum StorageProtocol {
6244
#[rpc(tx=oneshot::Sender<()>)]
6345
Set(Set),
64-
#[rpc(tx=mpsc::Sender<GetItem>)]
46+
#[rpc(tx=StreamSender<String, Error>)]
6547
Get(Get),
6648
}
6749

@@ -70,31 +52,6 @@ struct StorageActor {
7052
state: BTreeMap<String, String>,
7153
}
7254

73-
struct GetProgress {
74-
fut: BoxFuture<'static, irpc::Result<mpsc::Receiver<GetItem>>>,
75-
}
76-
77-
impl GetProgress {
78-
pub fn new(
79-
fut: impl Future<Output = irpc::Result<mpsc::Receiver<GetItem>>> + Send + 'static,
80-
) -> Self {
81-
Self { fut: Box::pin(fut) }
82-
}
83-
84-
pub fn stream(self) -> impl Stream<Item = anyhow::Result<String>> {
85-
self.fut.into_stream()
86-
}
87-
}
88-
89-
impl IntoFuture for GetProgress {
90-
type Output = anyhow::Result<String>;
91-
type IntoFuture = BoxFuture<'static, Self::Output>;
92-
93-
fn into_future(self) -> Self::IntoFuture {
94-
Box::pin(self.fut.try_collect())
95-
}
96-
}
97-
9855
impl StorageActor {
9956
pub fn spawn() -> StorageApi {
10057
let (tx, rx) = tokio::sync::mpsc::channel(1);
@@ -164,8 +121,12 @@ impl StorageApi {
164121
Ok(AbortOnDropHandle::new(join_handle))
165122
}
166123

167-
pub fn get(&self, key: String) -> GetProgress {
168-
GetProgress::new(self.inner.server_streaming(Get { key }, 16))
124+
pub fn get(&self, key: String) -> Progress<String, Error> {
125+
Progress::new(self.inner.server_streaming(Get { key }, 16))
126+
}
127+
128+
pub fn get_vec(&self, key: String) -> Progress<String, Error, Vec<String>> {
129+
Progress::new(self.inner.server_streaming(Get { key }, 16))
169130
}
170131

171132
pub async fn set(&self, key: String, value: String) -> irpc::Result<()> {
@@ -174,9 +135,13 @@ impl StorageApi {
174135
}
175136

176137
async fn client_demo(api: StorageApi) -> Result<()> {
177-
api.set("hello".to_string(), "world".to_string()).await?;
138+
api.set("hello".to_string(), "world and all".to_string())
139+
.await?;
178140
let value = api.get("hello".to_string()).await?;
179-
println!("get: hello = {value:?}");
141+
println!("get (string): hello = {value:?}");
142+
143+
let value = api.get_vec("hello".to_string()).await?;
144+
println!("get (vec): hello = {value:?}");
180145

181146
api.set("loremipsum".to_string(), "dolor sit amet".to_string())
182147
.await?;

irpc-derive/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -611,6 +611,7 @@ fn vis_pub() -> syn::Visibility {
611611
})
612612
}
613613

614+
// TODO(Frando): Remove if the generics approach works out fine?
614615
#[proc_macro_derive(StreamItem)]
615616
pub fn derive_irpc_stream_item(input: TokenStream) -> TokenStream {
616617
let input = parse_macro_input!(input as DeriveInput);

src/util.rs

Lines changed: 117 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -450,11 +450,112 @@ mod now_or_never {
450450
pub(crate) use now_or_never::now_or_never;
451451

452452
mod stream_item {
453-
use std::{future::Future, io};
453+
use std::{
454+
future::{Future, IntoFuture},
455+
io,
456+
marker::PhantomData,
457+
};
454458

459+
use futures_util::future::BoxFuture;
455460
use n0_future::{stream, Stream, StreamExt};
461+
use serde::{Deserialize, Serialize};
462+
463+
use crate::{
464+
channel::{mpsc, RecvError, SendError},
465+
RpcMessage,
466+
};
467+
468+
pub type StreamSender<T, E> = mpsc::Sender<Item<T, E>>;
469+
pub type StreamReceiver<T, E> = mpsc::Receiver<Item<T, E>>;
470+
471+
#[derive(thiserror::Error, Debug)]
472+
pub enum StreamError<E: std::error::Error> {
473+
#[error(transparent)]
474+
Transport(#[from] crate::Error),
475+
#[error(transparent)]
476+
Remote(E),
477+
}
478+
479+
impl<E: std::error::Error> From<crate::channel::RecvError> for StreamError<E> {
480+
fn from(value: crate::channel::RecvError) -> Self {
481+
Self::Transport(value.into())
482+
}
483+
}
484+
485+
pub type StreamResult<T, E> = std::result::Result<T, StreamError<E>>;
486+
487+
pub struct Progress<
488+
T: RpcMessage,
489+
E: RpcMessage + std::error::Error,
490+
C: Extend<T> + Default = T,
491+
> {
492+
fut: BoxFuture<'static, crate::Result<mpsc::Receiver<Item<T, E>>>>,
493+
_collection_type: PhantomData<C>,
494+
}
495+
496+
impl<T, E, C> Progress<T, E, C>
497+
where
498+
T: RpcMessage,
499+
E: RpcMessage + std::error::Error,
500+
C: Extend<T> + Default + Send,
501+
{
502+
pub fn new(
503+
fut: impl Future<Output = crate::Result<mpsc::Receiver<Item<T, E>>>> + Send + 'static,
504+
) -> Self {
505+
Self {
506+
fut: Box::pin(fut),
507+
_collection_type: PhantomData,
508+
}
509+
}
510+
511+
pub fn stream(self) -> impl Stream<Item = StreamResult<T, E>> {
512+
self.fut.into_stream()
513+
}
514+
}
515+
516+
impl<T, E, C> IntoFuture for Progress<T, E, C>
517+
where
518+
T: RpcMessage,
519+
E: RpcMessage + std::error::Error,
520+
C: Default + Extend<T> + Send + 'static,
521+
{
522+
type Output = StreamResult<C, E>;
523+
type IntoFuture = BoxFuture<'static, Self::Output>;
456524

457-
use crate::channel::{mpsc, RecvError, SendError};
525+
fn into_future(self) -> Self::IntoFuture {
526+
Box::pin(self.fut.try_collect())
527+
}
528+
}
529+
530+
#[derive(Debug, Serialize, Deserialize, Clone)]
531+
pub enum Item<T, E> {
532+
Ok(T),
533+
Err(E),
534+
Done,
535+
}
536+
537+
impl<T: RpcMessage, E: RpcMessage + std::error::Error> StreamItem for Item<T, E> {
538+
type Item = T;
539+
type Error = E;
540+
fn into_result_opt(self) -> Option<Result<Self::Item, Self::Error>> {
541+
match self {
542+
Item::Ok(item) => Some(Ok(item)),
543+
Item::Err(error) => Some(Err(error)),
544+
Item::Done => None,
545+
}
546+
}
547+
548+
fn from_result(item: std::result::Result<Self::Item, Self::Error>) -> Self {
549+
match item {
550+
Ok(item) => Self::Ok(item),
551+
Err(err) => Self::Err(err),
552+
}
553+
}
554+
555+
fn done() -> Self {
556+
Self::Done
557+
}
558+
}
458559

459560
/// Trait for an enum that has three variants, item, error, and done.
460561
///
@@ -463,9 +564,9 @@ mod stream_item {
463564
/// for successful end of stream.
464565
pub trait StreamItem: crate::RpcMessage {
465566
/// The error case of the item enum.
466-
type Error;
567+
type Error: crate::RpcMessage + std::error::Error;
467568
/// The item case of the item enum.
468-
type Item;
569+
type Item: crate::RpcMessage;
469570
/// Converts the stream item into either None for end of stream, or a Result
470571
/// containing the item or an error. Error is assumed as a termination, so
471572
/// if you get error you won't get an additional end of stream marker.
@@ -481,7 +582,6 @@ mod stream_item {
481582
///
482583
/// This will convert items and errors into the item enum type, and add
483584
/// a done marker if the stream ends without an error.
484-
#[allow(dead_code)]
485585
fn forward_stream(
486586
self,
487587
stream: impl Stream<Item = std::result::Result<T::Item, T::Error>>,
@@ -538,9 +638,7 @@ mod stream_item {
538638
fn try_collect<C, E>(self) -> impl Future<Output = std::result::Result<C, E>>
539639
where
540640
C: Default + Extend<T::Item>,
541-
E: From<T::Error>,
542-
E: From<crate::Error>,
543-
E: From<RecvError>;
641+
E: From<StreamError<T::Error>>;
544642

545643
/// Converts the receiver returned by this future into a stream of items,
546644
/// where each item is either a successful item or an error.
@@ -550,9 +648,7 @@ mod stream_item {
550648
/// first item and then terminate.
551649
fn into_stream<E>(self) -> impl Stream<Item = std::result::Result<T::Item, E>>
552650
where
553-
E: From<T::Error>,
554-
E: From<crate::Error>,
555-
E: From<RecvError>;
651+
E: From<StreamError<T::Error>>;
556652
}
557653

558654
impl<T, F> IrpcReceiverFutExt<T> for F
@@ -563,9 +659,7 @@ mod stream_item {
563659
async fn try_collect<C, E>(self) -> std::result::Result<C, E>
564660
where
565661
C: Default + Extend<T::Item>,
566-
E: From<T::Error>,
567-
E: From<crate::Error>,
568-
E: From<RecvError>,
662+
E: From<StreamError<T::Error>>,
569663
{
570664
let mut items = C::default();
571665
let mut stream = self.into_stream::<E>();
@@ -580,9 +674,7 @@ mod stream_item {
580674

581675
fn into_stream<E>(self) -> impl Stream<Item = std::result::Result<T::Item, E>>
582676
where
583-
E: From<T::Error>,
584-
E: From<crate::Error>,
585-
E: From<RecvError>,
677+
E: From<StreamError<T::Error>>,
586678
{
587679
enum State<S, T> {
588680
Init(S),
@@ -597,24 +689,22 @@ mod stream_item {
597689
) -> Option<(std::result::Result<T::Item, E>, State<S, T>)>
598690
where
599691
T: StreamItem,
600-
E: From<T::Error>,
601-
E: From<crate::Error>,
602-
E: From<RecvError>,
692+
E: From<StreamError<T::Error>>,
603693
{
604694
match rx.recv().await {
605695
Ok(Some(item)) => match item.into_result_opt()? {
606696
Ok(i) => Some((Ok(i), State::Receiving(rx))),
607-
Err(e) => Some((Err(E::from(e)), State::Done)),
697+
Err(e) => Some((Err(E::from(StreamError::Remote(e))), State::Done)),
608698
},
609-
Ok(None) => Some((Err(E::from(eof())), State::Done)),
610-
Err(e) => Some((Err(E::from(e)), State::Done)),
699+
Ok(None) => Some((Err(E::from(StreamError::from(eof()))), State::Done)),
700+
Err(e) => Some((Err(E::from(StreamError::from(e))), State::Done)),
611701
}
612702
}
613703
Box::pin(stream::unfold(State::Init(self), |state| async move {
614704
match state {
615705
State::Init(fut) => match fut.await {
616706
Ok(rx) => process_recv(rx).await,
617-
Err(e) => Some((Err(E::from(e)), State::Done)),
707+
Err(e) => Some((Err(E::from(StreamError::from(e))), State::Done)),
618708
},
619709
State::Receiving(rx) => process_recv(rx).await,
620710
State::Done => None,
@@ -629,4 +719,6 @@ mod stream_item {
629719
pub use irpc_derive::StreamItem;
630720
#[cfg(feature = "stream")]
631721
#[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "stream")))]
632-
pub use stream_item::{IrpcReceiverFutExt, MpscSenderExt, StreamItem};
722+
pub use stream_item::{
723+
IrpcReceiverFutExt, Item, MpscSenderExt, Progress, StreamItem, StreamReceiver, StreamSender,
724+
};

0 commit comments

Comments
 (0)