Skip to content

Commit

Permalink
Add support for wss (#360)
Browse files Browse the repository at this point in the history
Make rust-web3 work with an infura-like endpoint.
  • Loading branch information
en authored Jun 26, 2020
1 parent 06fa936 commit c06ce7d
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 7 deletions.
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ hyper-tls = { version = "0.4", optional = true }
native-tls = { version = "0.2", optional = true }
url = { version = "2.1.0", optional = true }
## WS
async-native-tls = { version = "0.3", optional = true }
async-std = { version = "1.5.0", optional = true }
soketto = { version = "0.4.1", optional = true }

Expand All @@ -50,6 +51,6 @@ async-std = { version = "1.5.0", features = ["attributes"] }
default = ["http", "ws", "tls"]
http = ["hyper", "url", "base64"]
tls = ["hyper-tls", "native-tls"]
ws = ["soketto", "async-std"]
ws = ["soketto", "async-std", "async-native-tls"]

[workspace]
82 changes: 76 additions & 6 deletions src/transports/ws.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
//! WebSocket Transport

use std::collections::BTreeMap;
use std::marker::Unpin;
use std::sync::{atomic, Arc};
use std::{fmt, pin::Pin};

Expand All @@ -14,10 +15,13 @@ use futures::{
task::{Context, Poll},
Future, FutureExt, Stream, StreamExt,
};
use futures::{AsyncRead, AsyncWrite};

use async_native_tls::TlsStream;
use async_std::net::TcpStream;
use soketto::connection;
use soketto::handshake::{Client, ServerResponse};
use url::Url;

impl From<soketto::handshake::Error> for Error {
fn from(err: soketto::handshake::Error) -> Self {
Expand All @@ -36,20 +40,85 @@ type BatchResult = error::Result<Vec<SingleResult>>;
type Pending = oneshot::Sender<BatchResult>;
type Subscription = mpsc::UnboundedSender<rpc::Value>;

/// Stream, either plain TCP or TLS.
enum MaybeTlsStream<S> {
/// Unencrypted socket stream.
Plain(S),
/// Encrypted socket stream.
Tls(TlsStream<S>),
}

impl<S> AsyncRead for MaybeTlsStream<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
fn poll_read(self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8]) -> Poll<Result<usize, std::io::Error>> {
match self.get_mut() {
MaybeTlsStream::Plain(ref mut s) => Pin::new(s).poll_read(cx, buf),
MaybeTlsStream::Tls(ref mut s) => Pin::new(s).poll_read(cx, buf),
}
}
}

impl<S> AsyncWrite for MaybeTlsStream<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<Result<usize, std::io::Error>> {
match self.get_mut() {
MaybeTlsStream::Plain(ref mut s) => Pin::new(s).poll_write(cx, buf),
MaybeTlsStream::Tls(ref mut s) => Pin::new(s).poll_write(cx, buf),
}
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), std::io::Error>> {
match self.get_mut() {
MaybeTlsStream::Plain(ref mut s) => Pin::new(s).poll_flush(cx),
MaybeTlsStream::Tls(ref mut s) => Pin::new(s).poll_flush(cx),
}
}

fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), std::io::Error>> {
match self.get_mut() {
MaybeTlsStream::Plain(ref mut s) => Pin::new(s).poll_close(cx),
MaybeTlsStream::Tls(ref mut s) => Pin::new(s).poll_close(cx),
}
}
}

struct WsServerTask {
pending: BTreeMap<RequestId, Pending>,
subscriptions: BTreeMap<SubscriptionId, Subscription>,
sender: connection::Sender<TcpStream>,
receiver: connection::Receiver<TcpStream>,
sender: connection::Sender<MaybeTlsStream<TcpStream>>,
receiver: connection::Receiver<MaybeTlsStream<TcpStream>>,
}

impl WsServerTask {
/// Create new WebSocket transport.
pub async fn new(url: &str) -> error::Result<Self> {
let url = url.trim_start_matches("ws://");
let url = Url::parse(url)?;

let scheme = match url.scheme() {
s if s == "ws" || s == "wss" => s,
s => return Err(error::Error::Transport(format!("Wrong scheme: {}", s))),
};
let host = match url.host_str() {
Some(s) => s,
None => return Err(error::Error::Transport("Wrong host name".to_string())),
};
let port = url.port().unwrap_or(if scheme == "ws" { 80 } else { 443 });
let addrs = format!("{}:{}", host, port);

let stream = TcpStream::connect(addrs).await?;

let socket = if scheme == "wss" {
let stream = async_native_tls::connect(host, stream).await?;
MaybeTlsStream::Tls(stream)
} else {
MaybeTlsStream::Plain(stream)
};

let socket = TcpStream::connect(url).await?;
let mut client = Client::new(socket, url, "/");
let mut client = Client::new(socket, host, url.path());
let handshake = client.handshake();
let (sender, receiver) = match handshake.await? {
ServerResponse::Accepted { .. } => client.into_builder().finish(),
Expand Down Expand Up @@ -370,7 +439,8 @@ mod tests {
let addr = "127.0.0.1:3000";
async_std::task::spawn(server(addr));

let ws = WebSocket::new(addr).await.unwrap();
let endpoint = "ws://127.0.0.1:3000";
let ws = WebSocket::new(endpoint).await.unwrap();

// when
let res = ws.execute("eth_accounts", vec![rpc::Value::String("1".into())]);
Expand Down

0 comments on commit c06ce7d

Please sign in to comment.