Skip to content

Commit

Permalink
Use oneshot channel in Client::send_request()
Browse files Browse the repository at this point in the history
This changes the `pending_requests` map to use a oneshot async channel
to simplify the code and better capitalize on the async runtime.

The previous implementation used a `loop` to poll the `pending_requests`
hashmap and explicitly yield execution back to the executor by calling
`tokio::task::yield_now()` if the client's response hasn't been received
yet by the server.

This should hopefully be a bit more efficient, given that
`futures::channel::oneshot::channel()` is presumably using the
underlying runtime's task scheduler more efficiently. This assumption
hasn't been tested with dedicated benchmarks, however, but at least the
code is much easier to reason about.

See original suggestion here:
#134 (comment)

We also add a private `RequestMap` type alias to help `rustc` infer the
type of the `pending_requests` map inside the `tokio::spawn()` async
block.
  • Loading branch information
ebkalderon committed Mar 4, 2020
1 parent 2843da6 commit f2bc58b
Showing 1 changed file with 23 additions and 26 deletions.
49 changes: 23 additions & 26 deletions src/delegate/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use std::sync::Arc;

use dashmap::DashMap;
use futures::channel::mpsc::{Receiver, Sender};
use futures::channel::oneshot;
use futures::sink::SinkExt;
use futures::stream::StreamExt;
use jsonrpc_core::types::{ErrorCode, Id, Output, Version};
Expand All @@ -19,13 +20,16 @@ use serde_json::Value;

use super::not_initialized_error;

/// Maps all pending client request IDs to their future responses.
type RequestMap = DashMap<u64, oneshot::Sender<Output>>;

/// Handle for communicating with the language client.
#[derive(Debug)]
pub struct Client {
sender: Sender<String>,
initialized: Arc<AtomicBool>,
request_id: AtomicU64,
pending_requests: Arc<DashMap<u64, Option<Output>>>,
pending_requests: Arc<RequestMap>,
}

impl Client {
Expand All @@ -34,17 +38,18 @@ impl Client {
mut receiver: Receiver<Output>,
initialized: Arc<AtomicBool>,
) -> Self {
let pending_requests = Arc::new(DashMap::default());
let pending_requests = Arc::new(RequestMap::default());

let pending = pending_requests.clone();
tokio::spawn(async move {
while let Some(response) = receiver.next().await {
match response.id() {
Id::Num(ref id) if pending.contains_key(id) => {
pending.insert(*id, Some(response));
if let Id::Num(ref id) = response.id() {
match pending.remove(id) {
Some((_, tx)) => tx.send(response).expect("receiver already dropped"),
None => error!("received response from client with no matching request"),
}
Id::Num(_) => error!("received response from client with no matching request"),
_ => error!("received response from client with non-numeric ID"),
} else {
error!("received response from client with non-numeric ID");
}
}
});
Expand Down Expand Up @@ -292,25 +297,17 @@ impl Client {
return Err(Error::internal_error());
}

self.pending_requests.insert(id, None);

loop {
let response = self
.pending_requests
.remove_if(&id, |_, v| v.is_some())
.and_then(|(_, v)| v);

match response {
Some(Output::Success(s)) => {
return serde_json::from_value(s.result).map_err(|e| Error {
code: ErrorCode::ParseError,
message: e.to_string(),
data: None,
});
}
Some(Output::Failure(f)) => return Err(f.error),
None => tokio::task::yield_now().await,
}
let (tx, rx) = oneshot::channel();
self.pending_requests.insert(id, tx);
let response = rx.await.expect("sender already dropped");

match response {
Output::Success(s) => serde_json::from_value(s.result).map_err(|e| Error {
code: ErrorCode::ParseError,
message: e.to_string(),
data: None,
}),
Output::Failure(f) => Err(f.error),
}
}

Expand Down

0 comments on commit f2bc58b

Please sign in to comment.