Skip to content
This repository has been archived by the owner on Oct 19, 2024. It is now read-only.

Refactor WS handling code #397

Merged
merged 7 commits into from
Aug 20, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ethers-providers/src/pubsub.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ where
///
/// ### Note
/// Most providers treat `SubscriptionStream` IDs as global singletons.
/// Instanitating this directly with a known ID will likely cause any
/// Instantiating this directly with a known ID will likely cause any
/// existing streams with that ID to end. To avoid this, start a new stream
/// using [`Provider::subscribe`] instead of `SubscriptionStream::new`.
pub fn new(id: U256, provider: &'a Provider<P>) -> Result<Self, P::Error> {
Expand Down
218 changes: 122 additions & 96 deletions ethers-providers/src/transports/ws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,30 @@ use tokio_tungstenite::{
};
use tracing::{error, warn};

type Pending = oneshot::Sender<Result<serde_json::Value, JsonRpcError>>;
type Subscription = mpsc::UnboundedSender<serde_json::Value>;

/// Instructions for the `WsServer`.
enum Instruction {
/// JSON-RPC request
Request {
id: u64,
request: String,
sender: Pending,
},
/// Create a new subscription
Subscribe { id: U256, sink: Subscription },
/// Cancel an existing subscription
Unsubscribe { id: U256 },
}

#[derive(Debug, serde::Deserialize)]
#[serde(untagged)]
enum Incoming {
Notification(Notification<serde_json::Value>),
Response(Response<serde_json::Value>),
}

/// A JSON-RPC Client over Websockets.
///
/// ```no_run
Expand All @@ -43,25 +67,7 @@ use tracing::{error, warn};
#[derive(Clone)]
pub struct Ws {
id: Arc<AtomicU64>,
requests: mpsc::UnboundedSender<TransportMessage>,
}

type Pending = oneshot::Sender<Result<serde_json::Value, JsonRpcError>>;
type Subscription = mpsc::UnboundedSender<serde_json::Value>;

enum TransportMessage {
Request {
id: u64,
request: String,
sender: Pending,
},
Subscribe {
id: U256,
sink: Subscription,
},
Unsubscribe {
id: U256,
},
instructions: mpsc::UnboundedSender<Instruction>,
}

impl Debug for Ws {
Expand Down Expand Up @@ -90,13 +96,13 @@ impl Ws {

Self {
id: Arc::new(AtomicU64::new(0)),
requests: sink,
instructions: sink,
}
}

/// Returns true if the WS connection is active, false otherwise
pub fn ready(&self) -> bool {
!self.requests.is_closed()
!self.instructions.is_closed()
}

/// Initializes a new WebSocket Client
Expand All @@ -107,8 +113,10 @@ impl Ws {
Ok(Self::new(ws))
}

fn send(&self, msg: TransportMessage) -> Result<(), ClientError> {
self.requests.unbounded_send(msg).map_err(to_client_error)
fn send(&self, msg: Instruction) -> Result<(), ClientError> {
self.instructions
.unbounded_send(msg)
.map_err(to_client_error)
}
}

Expand All @@ -125,14 +133,14 @@ impl JsonRpcClient for Ws {

// send the message
let (sender, receiver) = oneshot::channel();
let payload = TransportMessage::Request {
let payload = Instruction::Request {
id: next_id,
request: serde_json::to_string(&Request::new(next_id, method, params))?,
sender,
};

// send the data
self.send(payload).map_err(to_client_error)?;
self.send(payload)?;

// wait for the response
let res = receiver.await?;
Expand All @@ -150,21 +158,21 @@ impl PubsubClient for Ws {

fn subscribe<T: Into<U256>>(&self, id: T) -> Result<Self::NotificationStream, ClientError> {
let (sink, stream) = mpsc::unbounded();
self.send(TransportMessage::Subscribe {
self.send(Instruction::Subscribe {
id: id.into(),
sink,
})?;
Ok(stream)
}

fn unsubscribe<T: Into<U256>>(&self, id: T) -> Result<(), ClientError> {
self.send(TransportMessage::Unsubscribe { id: id.into() })
self.send(Instruction::Unsubscribe { id: id.into() })
}
}

struct WsServer<S> {
ws: Fuse<S>,
requests: Fuse<mpsc::UnboundedReceiver<TransportMessage>>,
instructions: Fuse<mpsc::UnboundedReceiver<Instruction>>,

pending: BTreeMap<u64, Pending>,
subscriptions: BTreeMap<U256, Subscription>,
Expand All @@ -179,12 +187,12 @@ where
+ Unpin,
{
/// Instantiates the Websocket Server
fn new(ws: S, requests: mpsc::UnboundedReceiver<TransportMessage>) -> Self {
fn new(ws: S, requests: mpsc::UnboundedReceiver<Instruction>) -> Self {
Self {
// Fuse the 2 steams together, so that we can `select` them in the
// Stream implementation
ws: ws.fuse(),
requests: requests.fuse(),
instructions: requests.fuse(),
pending: BTreeMap::default(),
subscriptions: BTreeMap::default(),
}
Expand All @@ -197,13 +205,13 @@ where
{
let f = async move {
loop {
match self.process().await {
match self.tick().await {
Err(ClientError::UnexpectedClose) => {
tracing::error!("{}", ClientError::UnexpectedClose);
break;
}
Err(_) => {
panic!("WS Server panic");
Err(e) => {
panic!("WS Server panic: {}", e);
}
_ => {}
}
Expand All @@ -213,63 +221,84 @@ where
tokio::spawn(f);
}

/// Processes 1 item selected from the incoming `requests` or `ws`
#[allow(clippy::single_match)]
async fn process(&mut self) -> Result<(), ClientError> {
futures_util::select! {
// Handle requests
msg = self.requests.select_next_some() => {
self.handle_request(msg).await?;
},
// Handle ws messages
msg = self.ws.next() => match msg {
Some(Ok(msg)) => self.handle_ws(msg).await?,
// TODO: Log the error?
Some(Err(_)) => {},
None => {
return Err(ClientError::UnexpectedClose);
},
}
};
// dispatch an RPC request
async fn service_request(
&mut self,
id: u64,
request: String,
sender: Pending,
) -> Result<(), ClientError> {
if self.pending.insert(id, sender).is_some() {
warn!("Replacing a pending request with id {:?}", id);
}

if let Err(e) = self.ws.send(Message::Text(request)).await {
error!("WS connection error: {:?}", e);
self.pending.remove(&id);
}
Ok(())
}

async fn handle_request(&mut self, msg: TransportMessage) -> Result<(), ClientError> {
match msg {
TransportMessage::Request {
/// Dispatch a subscription request
async fn service_subscribe(&mut self, id: U256, sink: Subscription) -> Result<(), ClientError> {
if self.subscriptions.insert(id, sink).is_some() {
warn!("Replacing already-registered subscription with id {:?}", id);
}
Ok(())
}

/// Dispatch a unsubscribe request
async fn service_unsubscribe(&mut self, id: U256) -> Result<(), ClientError> {
if self.subscriptions.remove(&id).is_none() {
warn!(
"Unsubscribing from non-existent subscription with id {:?}",
id
);
}
Ok(())
}

/// Dispatch an outgoing message
async fn service(&mut self, instruction: Instruction) -> Result<(), ClientError> {
match instruction {
Instruction::Request {
id,
request,
sender,
} => {
if self.pending.insert(id, sender).is_some() {
warn!("Replacing a pending request with id {:?}", id);
}
} => self.service_request(id, request, sender).await,
Instruction::Subscribe { id, sink } => self.service_subscribe(id, sink).await,
Instruction::Unsubscribe { id } => self.service_unsubscribe(id).await,
}
}

if let Err(e) = self.ws.send(Message::Text(request)).await {
error!("WS connection error: {:?}", e);
self.pending.remove(&id);
}
}
TransportMessage::Subscribe { id, sink } => {
if self.subscriptions.insert(id, sink).is_some() {
warn!("Replacing already-registered subscription with id {:?}", id);
async fn handle_ping(&mut self, inner: Vec<u8>) -> Result<(), ClientError> {
self.ws.send(Message::Pong(inner)).await?;
Ok(())
}

async fn handle_text(&mut self, inner: String) -> Result<(), ClientError> {
match serde_json::from_str::<Incoming>(&inner) {
Err(_) => {}
Ok(Incoming::Response(resp)) => {
if let Some(request) = self.pending.remove(&resp.id) {
request
.send(resp.data.into_result())
.map_err(to_client_error)?;
}
}
TransportMessage::Unsubscribe { id } => {
if self.subscriptions.remove(&id).is_none() {
warn!(
"Unsubscribing from non-existent subscription with id {:?}",
id
);
Ok(Incoming::Notification(notification)) => {
let id = notification.params.subscription;
if let Some(stream) = self.subscriptions.get(&id) {
stream
.unbounded_send(notification.params.result)
.map_err(to_client_error)?;
}
}
};

}
Ok(())
}

async fn handle_ws(&mut self, resp: Message) -> Result<(), ClientError> {
async fn handle(&mut self, resp: Message) -> Result<(), ClientError> {
match resp {
Message::Text(inner) => self.handle_text(inner).await,
Message::Ping(inner) => self.handle_ping(inner).await,
Expand All @@ -280,28 +309,25 @@ where
}
}

async fn handle_ping(&mut self, inner: Vec<u8>) -> Result<(), ClientError> {
self.ws.send(Message::Pong(inner)).await?;
Ok(())
}

async fn handle_text(&mut self, inner: String) -> Result<(), ClientError> {
if let Ok(resp) = serde_json::from_str::<Response<serde_json::Value>>(&inner) {
if let Some(request) = self.pending.remove(&resp.id) {
request
.send(resp.data.into_result())
.map_err(to_client_error)?;
}
} else if let Ok(notification) =
serde_json::from_str::<Notification<serde_json::Value>>(&inner)
{
let id = notification.params.subscription;
if let Some(stream) = self.subscriptions.get(&id) {
stream
.unbounded_send(notification.params.result)
.map_err(to_client_error)?;
/// Processes 1 instruction or 1 incoming websocket message
#[allow(clippy::single_match)]
async fn tick(&mut self) -> Result<(), ClientError> {
futures_util::select! {
// Handle requests
instruction = self.instructions.select_next_some() => {
self.service(instruction).await?;
},
// Handle ws messages
resp = self.ws.next() => match resp {
Some(Ok(resp)) => self.handle(resp).await?,
// TODO: Log the error?
Some(Err(_)) => {},
None => {
return Err(ClientError::UnexpectedClose);
},
}
}
};

Ok(())
}
}
Expand Down