Skip to content
This repository has been archived by the owner on Sep 4, 2024. It is now read-only.

Commit

Permalink
conditional compilation for async roundtripper
Browse files Browse the repository at this point in the history
Signed-off-by: Gregory Hill <gregorydhill@outlook.com>
  • Loading branch information
gregdhill committed Oct 9, 2020
1 parent e651798 commit 03d13f8
Show file tree
Hide file tree
Showing 4 changed files with 181 additions and 95 deletions.
4 changes: 4 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ documentation = "https://docs.rs/jsonrpc/"
description = "Rust support for the JSON-RPC 2.0 protocol"
keywords = [ "protocol", "json", "http", "jsonrpc" ]
readme = "README.md"
edition = "2018"

[features]
async = []

[lib]
name = "jsonrpc"
Expand Down
262 changes: 173 additions & 89 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,18 @@
//! and parsing responses
//!
use std::{error, io};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::{error, io};

use serde;
use base64;
use http;
use serde;
use serde_json;

use super::{Request, Response};
use util::HashableValue;
use error::Error;
use crate::error::Error;
use crate::util::HashableValue;

/// An interface for an HTTP roundtripper that handles HTTP requests.
pub trait HttpRoundTripper {
Expand All @@ -38,11 +38,27 @@ pub trait HttpRoundTripper {
/// The type for errors generated by the roundtripper.
type Err: error::Error;

/// Make an HTTP request. In practice only POST request will be made.
/// Make a synchronous HTTP request. In practice only POST request will be made.
#[cfg(not(feature = "async"))]
fn request(
&self,
http::Request<&[u8]>,
_request: http::Request<&[u8]>,
) -> Result<http::Response<Self::ResponseBody>, Self::Err>;

/// Make an asynchronous HTTP request. In practice only POST request will be made.
#[cfg(feature = "async")]
fn request<'life>(
&'life self,
_request: http::Request<&'life [u8]>,
) -> std::pin::Pin<
Box<
dyn std::future::Future<Output = Result<http::Response<Self::ResponseBody>, Self::Err>>
+ Send
+ 'life,
>,
>
where
Self: Sync + 'life;
}

/// A handle to a remote JSONRPC server
Expand All @@ -54,7 +70,21 @@ pub struct Client<R: HttpRoundTripper> {
nonce: Arc<Mutex<u64>>,
}

impl<Rt: HttpRoundTripper + 'static> Client<Rt> {
#[cfg(not(feature = "async"))]
macro_rules! maybe_async_fn {
($($tokens:tt)*) => {
$($tokens)*
};
}

#[cfg(feature = "async")]
macro_rules! maybe_async_fn {
($(#[$($meta:meta)*])* $vis:vis $ident:ident $($tokens:tt)*) => {
$(#[$($meta)*])* $vis async $ident $($tokens)*
};
}

impl<Rt: HttpRoundTripper + 'static + Sync> Client<Rt> {
/// Creates a new client
pub fn new(
roundtripper: Rt,
Expand All @@ -74,104 +104,139 @@ impl<Rt: HttpRoundTripper + 'static> Client<Rt> {
}
}

/// Make a request and deserialize the response
pub fn do_rpc<T: for<'a> serde::de::Deserialize<'a>>(
&self,
rpc_name: &str,
args: &[serde_json::value::Value],
) -> Result<T, Error> {
let request = self.build_request(rpc_name, args);
let response = self.send_request(&request)?;
maybe_async_fn! {
/// Make a request and deserialize the response
pub fn do_rpc<T: for<'a> serde::de::Deserialize<'a>>(
&self,
rpc_name: &str,
args: &[serde_json::value::Value],
) -> Result<T, Error> {
let request = self.build_request(rpc_name, args);

#[cfg(not(feature = "async"))]
let response = self.send_request(&request)?;

Ok(response.into_result()?)
#[cfg(feature = "async")]
let response = self.send_request(&request).await?;

Ok(response.into_result()?)
}
}

/// The actual send logic used by both [send_request] and [send_batch].
fn send_raw<B, R>(&self, body: &B) -> Result<R, Error>
where
B: serde::ser::Serialize,
R: for<'de> serde::de::Deserialize<'de>,
{
// Build request
let request_raw = serde_json::to_vec(body)?;

// Send request
let mut request_builder = http::Request::post(&self.url);
request_builder.header("Content-Type", "application/json-rpc");

// Set Authorization header
if let Some(ref user) = self.user {
let mut auth = user.clone();
auth.push(':');
if let Some(ref pass) = self.pass {
auth.push_str(&pass[..]);
maybe_async_fn! {
/// The actual send logic used by both [send_request] and [send_batch].
fn send_raw<B, R>(&self, body: &B) -> Result<R, Error>
where
B: serde::ser::Serialize,
R: for<'de> serde::de::Deserialize<'de>,
{
// Build request
let request_raw = serde_json::to_vec(body)?;

// Send request
let mut request_builder = http::Request::post(&self.url);
request_builder.header("Content-Type", "application/json-rpc");

// Set Authorization header
if let Some(ref user) = self.user {
let mut auth = user.clone();
auth.push(':');
if let Some(ref pass) = self.pass {
auth.push_str(&pass[..]);
}
let value = format!("Basic {}", &base64::encode(auth.as_bytes()));
request_builder.header("Authorization", value);
}
let value = format!("Basic {}", &base64::encode(auth.as_bytes()));
request_builder.header("Authorization", value);
}

// Errors only on invalid header or builder reuse.
let http_request = request_builder.body(&request_raw[..]).unwrap();
// Errors only on invalid header or builder reuse.
let http_request = request_builder.body(&request_raw[..]).unwrap();

let http_response =
self.roundtripper.request(http_request).map_err(|e| Error::Http(Box::new(e)))?;
#[cfg(not(feature = "async"))]
let http_response = self
.roundtripper
.request(http_request)
.map_err(|e| Error::Http(Box::new(e)))?;

// nb we ignore stream.status since we expect the body
// to contain information about any error
Ok(serde_json::from_reader(http_response.into_body())?)
}
#[cfg(feature = "async")]
let http_response = self
.roundtripper
.request(http_request).await
.map_err(|e| Error::Http(Box::new(e)))?;

/// Sends a request to a client
pub fn send_request(&self, request: &Request) -> Result<Response, Error> {
let response: Response = self.send_raw(&request)?;
if response.jsonrpc != None && response.jsonrpc != Some(From::from("2.0")) {
return Err(Error::VersionMismatch);
}
if response.id != request.id {
return Err(Error::NonceMismatch);

// nb we ignore stream.status since we expect the body
// to contain information about any error
Ok(serde_json::from_reader(http_response.into_body())?)
}
Ok(response)
}

/// Sends a batch of requests to the client. The return vector holds the response
/// for the request at the corresponding index. If no response was provided, it's [None].
///
/// Note that the requests need to have valid IDs, so it is advised to create the requests
/// with [build_request].
pub fn send_batch(&self, requests: &[Request]) -> Result<Vec<Option<Response>>, Error> {
if requests.len() < 1 {
return Err(Error::EmptyBatch);
}
maybe_async_fn! {
/// Sends a request to a client
pub fn send_request<'a, 'b>(&self, request: &Request<'a, 'b>) -> Result<Response, Error> {
#[cfg(not(feature = "async"))]
let response: Response = self.send_raw(&request)?;

// If the request body is invalid JSON, the response is a single response object.
// We ignore this case since we are confident we are producing valid JSON.
let responses: Vec<Response> = self.send_raw(&requests)?;
if responses.len() > requests.len() {
return Err(Error::WrongBatchResponseSize);
}
#[cfg(feature = "async")]
let response: Response = self.send_raw(&request).await?;

// To prevent having to clone responses, we first copy all the IDs so we can reference
// them easily. IDs can only be of JSON type String or Number (or Null), so cloning
// should be inexpensive and require no allocations as Numbers are more common.
let ids: Vec<serde_json::Value> = responses.iter().map(|r| r.id.clone()).collect();
// First index responses by ID and catch duplicate IDs.
let mut resp_by_id = HashMap::new();
for (id, resp) in ids.iter().zip(responses.into_iter()) {
if let Some(dup) = resp_by_id.insert(HashableValue(&id), resp) {
return Err(Error::BatchDuplicateResponseId(dup.id));
if response.jsonrpc != None && response.jsonrpc != Some(From::from("2.0")) {
return Err(Error::VersionMismatch);
}
if response.id != request.id {
return Err(Error::NonceMismatch);
}
Ok(response)
}
// Match responses to the requests.
let results =
requests.into_iter().map(|r| resp_by_id.remove(&HashableValue(&r.id))).collect();

// Since we're also just producing the first duplicate ID, we can also just produce the
// first incorrect ID in case there are multiple.
if let Some(incorrect) = resp_by_id.into_iter().nth(0) {
return Err(Error::WrongBatchResponseId(incorrect.1.id));
}
}

maybe_async_fn! {
/// Sends a batch of requests to the client. The return vector holds the response
/// for the request at the corresponding index. If no response was provided, it's [None].
///
/// Note that the requests need to have valid IDs, so it is advised to create the requests
/// with [build_request].
pub fn send_batch<'a, 'b>(&self, requests: &[Request<'a, 'b>]) -> Result<Vec<Option<Response>>, Error> {
if requests.len() < 1 {
return Err(Error::EmptyBatch);
}

// If the request body is invalid JSON, the response is a single response object.
// We ignore this case since we are confident we are producing valid JSON.
#[cfg(not(feature = "async"))]
let responses: Vec<Response> = self.send_raw(&requests)?;

Ok(results)
#[cfg(feature = "async")]
let responses: Vec<Response> = self.send_raw(&requests).await?;

if responses.len() > requests.len() {
return Err(Error::WrongBatchResponseSize);
}

// To prevent having to clone responses, we first copy all the IDs so we can reference
// them easily. IDs can only be of JSON type String or Number (or Null), so cloning
// should be inexpensive and require no allocations as Numbers are more common.
let ids: Vec<serde_json::Value> = responses.iter().map(|r| r.id.clone()).collect();
// First index responses by ID and catch duplicate IDs.
let mut resp_by_id = HashMap::new();
for (id, resp) in ids.iter().zip(responses.into_iter()) {
if let Some(dup) = resp_by_id.insert(HashableValue(&id), resp) {
return Err(Error::BatchDuplicateResponseId(dup.id));
}
}
// Match responses to the requests.
let results = requests
.into_iter()
.map(|r| resp_by_id.remove(&HashableValue(&r.id)))
.collect();

// Since we're also just producing the first duplicate ID, we can also just produce the
// first incorrect ID in case there are multiple.
if let Some(incorrect) = resp_by_id.into_iter().nth(0) {
return Err(Error::WrongBatchResponseId(incorrect.1.id));
}

Ok(results)
}
}

/// Builds a request
Expand Down Expand Up @@ -206,12 +271,31 @@ mod tests {
type ResponseBody = io::Empty;
type Err = io::Error;

#[cfg(not(feature = "async"))]
fn request(
&self,
_: http::Request<&[u8]>,
) -> Result<http::Response<Self::ResponseBody>, Self::Err> {
Err(io::ErrorKind::Other.into())
}

#[cfg(feature = "async")]
fn request<'life>(
&'life self,
request: http::Request<&[u8]>,
) -> std::pin::Pin<
Box<
dyn std::future::Future<
Output = Result<http::Response<Self::ResponseBody>, Self::Err>,
> + Send
+ 'life,
>,
>
where
Self: Sync + 'life,
{
Box::pin(async { Err(io::ErrorKind::Other.into()) })
}
}

#[test]
Expand Down
2 changes: 1 addition & 1 deletion src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use std::{error, fmt};

use serde_json;

use Response;
use crate::Response;

/// A library error
#[derive(Debug)]
Expand Down
8 changes: 3 additions & 5 deletions src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,18 +44,18 @@ impl<'a> Hash for HashableValue<'a> {
} else {
n.to_string().hash(state);
}
},
}
Value::String(ref s) => {
"string".hash(state);
s.hash(state);
},
}
Value::Array(ref v) => {
"array".hash(state);
v.len().hash(state);
for obj in v {
HashableValue(obj).hash(state);
}
},
}
Value::Object(ref m) => {
"object".hash(state);
m.len().hash(state);
Expand Down Expand Up @@ -116,5 +116,3 @@ mod tests {
assert!(coll.contains(&m));
}
}


0 comments on commit 03d13f8

Please sign in to comment.