Skip to content

Commit

Permalink
client: use tokio channels (#999)
Browse files Browse the repository at this point in the history
* client: use tokio channels

This PR replaces the future channels with tokio because the APIs fit our use-cases better.

* remove unused code

* fix wasm build

* fix docs

* fix tests

* fix more nits

* Update core/src/client/async_client/mod.rs

* Update core/src/client/async_client/mod.rs

* fix unwrap
  • Loading branch information
niklasad1 authored Jan 31, 2023
1 parent a330dae commit 1a2a199
Show file tree
Hide file tree
Showing 8 changed files with 72 additions and 65 deletions.
3 changes: 1 addition & 2 deletions client/ws-client/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,8 +213,7 @@ async fn notification_without_polling_doesnt_make_client_unuseable() {
// don't poll the notification stream for 2 seconds, should be full now.
tokio::time::sleep(std::time::Duration::from_secs(2)).await;

// Capacity is `num_sender` + `capacity`
for _ in 0..5 {
for _ in 0..4 {
assert!(nh.next().with_default_timeout().await.unwrap().unwrap().is_ok());
}

Expand Down
9 changes: 6 additions & 3 deletions core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ license = "MIT"
anyhow = "1"
async-trait = "0.1"
beef = { version = "0.5.1", features = ["impl_serde"] }
futures-channel = "0.3.14"
jsonrpsee-types = { path = "../types", version = "0.16.2" }
thiserror = "1"
serde = { version = "1.0", default-features = false, features = ["derive"] }
Expand All @@ -28,9 +27,11 @@ soketto = { version = "0.7.1", optional = true }
parking_lot = { version = "0.12", optional = true }
tokio = { version = "1.16", optional = true }
wasm-bindgen-futures = { version = "0.4.19", optional = true }
futures-channel = { version = "0.3.14", optional = true }
futures-timer = { version = "3", optional = true }
globset = { version = "0.4", optional = true }
http = { version = "0.2.7", optional = true }
tokio-stream = { version = "0.1", optional = true }

[features]
default = []
Expand All @@ -44,15 +45,16 @@ server = [
"rand",
"tokio/rt",
"tokio/sync",
"futures-channel",
]
client = ["futures-util/sink", "futures-channel/sink", "futures-channel/std"]
client = ["futures-util/sink", "tokio/sync"]
async-client = [
"async-lock",
"client",
"rustc-hash",
"tokio/macros",
"tokio/rt",
"tokio/sync",
"tokio-stream",
"futures-timer",
]
async-wasm-client = [
Expand All @@ -61,6 +63,7 @@ async-wasm-client = [
"wasm-bindgen-futures",
"rustc-hash/std",
"futures-timer/wasm-bindgen",
"tokio-stream",
]

[dev-dependencies]
Expand Down
8 changes: 4 additions & 4 deletions core/src/client/async_client/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ use crate::params::ArrayParams;
use crate::traits::ToRpcParams;
use crate::Error;

use futures_channel::mpsc;
use futures_timer::Delay;
use futures_util::future::{self, Either};
use tokio::sync::{mpsc, oneshot};

use jsonrpsee_types::error::CallError;
use jsonrpsee_types::response::SubscriptionError;
Expand Down Expand Up @@ -155,7 +155,7 @@ pub(crate) fn process_notification(manager: &mut RequestManager, notif: Notifica
Err(err) => {
tracing::error!("Error sending notification, dropping handler for {:?} error: {:?}", notif.method, err);
let _ = manager.remove_notification_handler(notif.method.into_owned());
Err(err.into_send_error().into())
Err(Error::Custom(err.to_string()))
}
},
None => {
Expand Down Expand Up @@ -274,8 +274,8 @@ pub(crate) fn process_error_response(manager: &mut RequestManager, err: ErrorRes
/// Wait for a stream to complete within the given timeout.
pub(crate) async fn call_with_timeout<T>(
timeout: std::time::Duration,
rx: futures_channel::oneshot::Receiver<Result<T, Error>>,
) -> Result<Result<T, Error>, futures_channel::oneshot::Canceled> {
rx: oneshot::Receiver<Result<T, Error>>,
) -> Result<Result<T, Error>, oneshot::error::RecvError> {
match future::select(rx, Delay::new(timeout)).await {
Either::Left((res, _)) => res,
Either::Right((_, _)) => Ok(Err(Error::RequestTimeout)),
Expand Down
4 changes: 2 additions & 2 deletions core/src/client/async_client/manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ use std::{
};

use crate::{client::BatchEntry, Error};
use futures_channel::{mpsc, oneshot};
use jsonrpsee_types::{Id, SubscriptionId};
use rustc_hash::FxHashMap;
use serde_json::value::Value as JsonValue;
use tokio::sync::{mpsc, oneshot};

#[derive(Debug)]
enum Kind {
Expand Down Expand Up @@ -312,9 +312,9 @@ impl RequestManager {
#[cfg(test)]
mod tests {
use super::{Error, RequestManager};
use futures_channel::{mpsc, oneshot};
use jsonrpsee_types::{Id, SubscriptionId};
use serde_json::Value as JsonValue;
use tokio::sync::{mpsc, oneshot};

#[test]
fn insert_remove_pending_request_works() {
Expand Down
97 changes: 53 additions & 44 deletions core/src/client/async_client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,16 @@ use manager::RequestManager;

use async_lock::Mutex;
use async_trait::async_trait;
use futures_channel::{mpsc, oneshot};
use futures_timer::Delay;
use futures_util::future::{self, Either, Fuse};
use futures_util::sink::SinkExt;
use futures_util::stream::StreamExt;
use futures_util::FutureExt;
use jsonrpsee_types::{
response::SubscriptionError, ErrorResponse, Notification, NotificationSer, RequestSer, Response,
SubscriptionResponse,
};
use serde::de::DeserializeOwned;
use tokio::sync::{mpsc, oneshot};
use tracing::instrument;

use super::{generate_batch_id_range, FrontToBack, IdKind, RequestIdManager};
Expand Down Expand Up @@ -115,8 +114,10 @@ impl ClientBuilder {
/// [`Subscription::next()`](../../jsonrpsee_core/client/struct.Subscription.html#method.next) such that
/// it can keep with the rate as server produces new items on the subscription.
///
/// **Note**: The actual capacity is `num_senders + max_subscription_capacity`
/// because it is passed to [`futures_channel::mpsc::channel`].
///
/// # Panics
///
/// This function panics if `max` is 0.
pub fn max_notifs_per_subscription(mut self, max: usize) -> Self {
self.max_notifs_per_subscription = max;
self
Expand Down Expand Up @@ -169,7 +170,7 @@ impl ClientBuilder {
let (err_tx, err_rx) = oneshot::channel();
let max_notifs_per_subscription = self.max_notifs_per_subscription;
let ping_interval = self.ping_interval;
let (on_close_tx, on_close_rx) = oneshot::channel();
let (on_exit_tx, on_exit_rx) = oneshot::channel();

tokio::spawn(async move {
background_task(
Expand All @@ -179,7 +180,7 @@ impl ClientBuilder {
err_tx,
max_notifs_per_subscription,
ping_interval,
on_close_tx,
on_exit_rx,
)
.await;
});
Expand All @@ -189,7 +190,7 @@ impl ClientBuilder {
error: Mutex::new(ErrorFromBack::Unread(err_rx)),
id_manager: RequestIdManager::new(self.max_concurrent_requests, self.id_kind),
max_log_length: self.max_log_length,
notify: Mutex::new(Some(on_close_rx)),
on_exit: Some(on_exit_tx),
}
}

Expand All @@ -204,18 +205,18 @@ impl ClientBuilder {
let (to_back, from_front) = mpsc::channel(self.max_concurrent_requests);
let (err_tx, err_rx) = oneshot::channel();
let max_notifs_per_subscription = self.max_notifs_per_subscription;
let (on_close_tx, on_close_rx) = oneshot::channel();
let (on_exit_tx, on_exit_rx) = oneshot::channel();

wasm_bindgen_futures::spawn_local(async move {
background_task(sender, receiver, from_front, err_tx, max_notifs_per_subscription, None, on_close_tx).await;
background_task(sender, receiver, from_front, err_tx, max_notifs_per_subscription, None, on_exit_rx).await;
});
Client {
to_back,
request_timeout: self.request_timeout,
error: Mutex::new(ErrorFromBack::Unread(err_rx)),
id_manager: RequestIdManager::new(self.max_concurrent_requests, self.id_kind),
max_log_length: self.max_log_length,
notify: Mutex::new(Some(on_close_rx)),
on_exit: Some(on_exit_tx),
}
}
}
Expand All @@ -236,10 +237,8 @@ pub struct Client {
///
/// Entries bigger than this limit will be truncated.
max_log_length: u32,
/// Notify when the client is disconnected or encountered an error.
// NOTE: Similar to error, the async fns use immutable references. The `Receiver` is wrapped
// into `Option` to ensure the `on_disconnect` awaits only once.
notify: Mutex<Option<oneshot::Receiver<()>>>,
/// When the client is dropped a message is sent to the background thread.
on_exit: Option<tokio::sync::oneshot::Sender<()>>,
}

impl Client {
Expand All @@ -264,17 +263,15 @@ impl Client {
///
/// This method is cancel safe.
pub async fn on_disconnect(&self) {
// Wait until the `background_task` exits.
let mut notify_lock = self.notify.lock().await;
if let Some(notify) = notify_lock.take() {
let _ = notify.await;
}
self.to_back.closed().await
}
}

impl Drop for Client {
fn drop(&mut self) {
self.to_back.close_channel();
if let Some(e) = self.on_exit.take() {
let _ = e.send(());
}
}
}

Expand All @@ -293,9 +290,11 @@ impl ClientT for Client {
let raw = serde_json::to_string(&notif).map_err(Error::ParseError)?;
tx_log_from_str(&raw, self.max_log_length);

let mut sender = self.to_back.clone();
let sender = self.to_back.clone();
let fut = sender.send(FrontToBack::Notification(raw));

tokio::pin!(fut);

match future::select(fut, Delay::new(self.request_timeout)).await {
Either::Left((Ok(()), _)) => Ok(()),
Either::Left((Err(_), _)) => Err(self.read_error_from_backend().await),
Expand Down Expand Up @@ -434,7 +433,7 @@ impl SubscriptionClientT for Client {

tx_log_from_str(&raw, self.max_log_length);

let (send_back_tx, send_back_rx) = oneshot::channel();
let (send_back_tx, send_back_rx) = tokio::sync::oneshot::channel();
if self
.to_back
.clone()
Expand Down Expand Up @@ -698,43 +697,46 @@ async fn handle_frontend_messages<S: TransportSenderT>(
async fn background_task<S, R>(
mut sender: S,
receiver: R,
mut frontend: mpsc::Receiver<FrontToBack>,
frontend: mpsc::Receiver<FrontToBack>,
front_error: oneshot::Sender<Error>,
max_notifs_per_subscription: usize,
ping_interval: Option<Duration>,
on_close: oneshot::Sender<()>,
on_exit: oneshot::Receiver<()>,
) where
S: TransportSenderT,
R: TransportReceiverT,
{
// Create either a valid delay fuse triggered every provided `duration`,
// or create a terminated fuse that's never selected if the provided `duration` is None.
fn ping_fut(ping_interval: Option<Duration>) -> Fuse<Delay> {
if let Some(duration) = ping_interval {
Delay::new(duration).fuse()
} else {
// The select macro bypasses terminated futures, and the `submit_ping` branch is never selected.
Fuse::<Delay>::terminated()
}
}

let mut manager = RequestManager::new();

let backend_event = futures_util::stream::unfold(receiver, |mut receiver| async {
let res = receiver.receive().await;
Some((res, receiver))
});
futures_util::pin_mut!(backend_event);
let frontend = tokio_stream::wrappers::ReceiverStream::new(frontend);

tokio::pin!(backend_event, frontend);

// Place frontend and backend messages into their own select.
// This implies that either messages are received (both front or backend),
// or the submitted ping timer expires (if provided).
let next_frontend = frontend.next();
let next_backend = backend_event.next();
let mut message_fut = future::select(next_frontend, next_backend);
let mut message_fut = future::select(frontend.next(), backend_event.next());
let mut exit_or_ping_fut = future::select(on_exit, ping_fut(ping_interval));

loop {
// Create either a valid delay fuse triggered every provided `duration`,
// or create a terminated fuse that's never selected if the provided `duration` is None.
let submit_ping = if let Some(duration) = ping_interval {
Delay::new(duration).fuse()
} else {
// The select macro bypasses terminated futures, and the `submit_ping` branch is never selected.
Fuse::<Delay>::terminated()
};

match future::select(message_fut, submit_ping).await {
match future::select(message_fut, exit_or_ping_fut).await {
// Message received from the frontend.
Either::Left((Either::Left((frontend_value, backend)), _)) => {
Either::Left((Either::Left((frontend_value, backend)), exit_or_ping)) => {
let frontend_value = if let Some(value) = frontend_value {
value
} else {
Expand All @@ -748,9 +750,10 @@ async fn background_task<S, R>(

// Advance frontend, save backend.
message_fut = future::select(frontend.next(), backend);
exit_or_ping_fut = exit_or_ping;
}
// Message received from the backend.
Either::Left((Either::Right((backend_value, frontend)), _)) => {
Either::Left((Either::Right((backend_value, frontend)), exit_or_ping)) => {
if let Err(err) = handle_backend_messages::<S, R>(
backend_value,
&mut manager,
Expand All @@ -765,23 +768,29 @@ async fn background_task<S, R>(
}
// Advance backend, save frontend.
message_fut = future::select(frontend, backend_event.next());
exit_or_ping_fut = exit_or_ping;
}
// The client is closed.
Either::Right((Either::Left((_, _)), _)) => {
break;
}
// Submit ping interval was triggered if enabled.
Either::Right((_, next_message_fut)) => {
Either::Right((Either::Right((_, on_exit)), msg)) => {
if let Err(err) = sender.send_ping().await {
tracing::error!("[backend]: Could not send ping frame: {}", err);
let _ = front_error.send(Error::Custom("Could not send ping frame".into()));
break;
}
message_fut = next_message_fut;
message_fut = msg;
exit_or_ping_fut = future::select(on_exit, ping_fut(ping_interval));
}
};
}

// Wake the `on_disconnect` method.
let _ = on_close.send(());
frontend.close();
// Send close message to the server.
let _ = sender.close().await;
_ = sender.close().await;
}

fn unparse_error(raw: &[u8]) -> Error {
Expand Down
10 changes: 5 additions & 5 deletions core/src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,12 @@ use crate::params::BatchRequestBuilder;
use crate::traits::ToRpcParams;
use async_trait::async_trait;
use core::marker::PhantomData;
use futures_channel::{mpsc, oneshot};
use futures_util::future::FutureExt;
use futures_util::sink::SinkExt;
use futures_util::stream::{Stream, StreamExt};
use jsonrpsee_types::{ErrorObject, Id, SubscriptionId};
use serde::de::DeserializeOwned;
use serde_json::Value as JsonValue;
use tokio::sync::{mpsc, oneshot};

// Re-exports for the `rpc_params` macro.
#[doc(hidden)]
Expand Down Expand Up @@ -256,10 +255,11 @@ impl<Notif> Subscription<Notif> {
SubscriptionKind::Method(notif) => FrontToBack::UnregisterNotification(notif),
SubscriptionKind::Subscription(sub_id) => FrontToBack::SubscriptionClosed(sub_id),
};
self.to_back.send(msg).await?;
// If this fails the connection was already closed i.e, already "unsubscribed".
let _ = self.to_back.send(msg).await;

// wait until notif channel is closed then the subscription was closed.
while self.notifs_rx.next().await.is_some() {}
while self.notifs_rx.recv().await.is_some() {}
Ok(())
}
}
Expand Down Expand Up @@ -360,7 +360,7 @@ where
{
type Item = Result<Notif, Error>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<Option<Self::Item>> {
let n = futures_util::ready!(self.notifs_rx.poll_next_unpin(cx));
let n = futures_util::ready!(self.notifs_rx.poll_recv(cx));
let res = n.map(|n| match serde_json::from_value::<Notif>(n) {
Ok(parsed) => Ok(parsed),
Err(e) => Err(Error::ParseError(e)),
Expand Down
Loading

0 comments on commit 1a2a199

Please sign in to comment.