diff --git a/http-client/src/tests.rs b/http-client/src/tests.rs
index 8ce53edb98..a7ba27cc8a 100644
--- a/http-client/src/tests.rs
+++ b/http-client/src/tests.rs
@@ -31,7 +31,7 @@ use crate::types::{
};
use crate::HttpClientBuilder;
use jsonrpsee_test_utils::helpers::*;
-use jsonrpsee_test_utils::types::Id;
+use jsonrpsee_test_utils::mocks::Id;
use jsonrpsee_test_utils::TimeoutFutureExt;
#[tokio::test]
diff --git a/http-server/src/tests.rs b/http-server/src/tests.rs
index f122128e76..55dcc9834c 100644
--- a/http-server/src/tests.rs
+++ b/http-server/src/tests.rs
@@ -32,7 +32,7 @@ use crate::types::error::{CallError, Error};
use crate::{server::StopHandle, HttpServerBuilder, RpcModule};
use jsonrpsee_test_utils::helpers::*;
-use jsonrpsee_test_utils::types::{Id, StatusCode, TestContext};
+use jsonrpsee_test_utils::mocks::{Id, StatusCode, TestContext};
use jsonrpsee_test_utils::TimeoutFutureExt;
use serde_json::Value as JsonValue;
use tokio::task::JoinHandle;
diff --git a/test-utils/Cargo.toml b/test-utils/Cargo.toml
index 05bdf3371e..e4bfac7f27 100644
--- a/test-utils/Cargo.toml
+++ b/test-utils/Cargo.toml
@@ -15,6 +15,6 @@ hyper = { version = "0.14.10", features = ["full"] }
log = "0.4"
serde = { version = "1", default-features = false, features = ["derive"] }
serde_json = "1"
-soketto = "0.7"
+soketto = { version = "0.7", features = ["http"] }
tokio = { version = "1", features = ["net", "rt-multi-thread", "macros", "time"] }
tokio-util = { version = "0.6", features = ["compat"] }
diff --git a/test-utils/src/helpers.rs b/test-utils/src/helpers.rs
index 06cbe068b3..7cb3e7521f 100644
--- a/test-utils/src/helpers.rs
+++ b/test-utils/src/helpers.rs
@@ -24,7 +24,7 @@
// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
-use crate::types::{Body, HttpResponse, Id, Uri};
+use crate::mocks::{Body, HttpResponse, Id, Uri};
use hyper::service::{make_service_fn, service_fn};
use hyper::{Request, Response, Server};
use serde_json::Value;
diff --git a/test-utils/src/lib.rs b/test-utils/src/lib.rs
index 6f6133eeaa..c47211bc62 100644
--- a/test-utils/src/lib.rs
+++ b/test-utils/src/lib.rs
@@ -32,7 +32,7 @@ use std::{future::Future, time::Duration};
use tokio::time::{timeout, Timeout};
pub mod helpers;
-pub mod types;
+pub mod mocks;
/// Helper extension trait which allows to limit execution time for the futures.
/// It is helpful in tests to ensure that no future will ever get stuck forever.
diff --git a/test-utils/src/types.rs b/test-utils/src/mocks.rs
similarity index 84%
rename from test-utils/src/types.rs
rename to test-utils/src/mocks.rs
index 5d3da21d6f..c3bb183ba0 100644
--- a/test-utils/src/types.rs
+++ b/test-utils/src/mocks.rs
@@ -34,10 +34,8 @@ use futures_util::{
stream::{self, StreamExt},
};
use serde::{Deserialize, Serialize};
-use soketto::handshake::{self, server::Response, Error as SokettoError, Server};
-use std::io;
-use std::net::SocketAddr;
-use std::time::Duration;
+use soketto::handshake::{self, http::is_upgrade_request, server::Response, Error as SokettoError, Server};
+use std::{io, net::SocketAddr, time::Duration};
use tokio::net::TcpStream;
use tokio_util::compat::{Compat, TokioAsyncReadCompatExt};
@@ -314,3 +312,60 @@ async fn connection_task(socket: tokio::net::TcpStream, mode: ServerMode, mut ex
}
}
}
+
+// Run a WebSocket server running on localhost that redirects requests for testing.
+// Requests to any url except for `/myblock/two` will redirect one or two times (HTTP 301) and eventually end up in `/myblock/two`.
+pub fn ws_server_with_redirect(other_server: String) -> String {
+ let addr = ([127, 0, 0, 1], 0).into();
+
+ let service = hyper::service::make_service_fn(move |_| {
+ let other_server = other_server.clone();
+ async move {
+ Ok::<_, hyper::Error>(hyper::service::service_fn(move |req| {
+ let other_server = other_server.clone();
+ async move { handler(req, other_server).await }
+ }))
+ }
+ });
+ let server = hyper::Server::bind(&addr).serve(service);
+ let addr = server.local_addr();
+
+ tokio::spawn(async move { server.await });
+ format!("ws://{}", addr)
+}
+
+/// Handle incoming HTTP Requests.
+async fn handler(
+ req: hyper::Request
,
+ other_server: String,
+) -> Result, soketto::BoxedError> {
+ if is_upgrade_request(&req) {
+ log::debug!("{:?}", req);
+
+ match req.uri().path() {
+ "/myblock/two" => {
+ let response = hyper::Response::builder()
+ .status(301)
+ .header("Location", other_server)
+ .body(Body::empty())
+ .unwrap();
+ Ok(response)
+ }
+ "/myblock/one" => {
+ let response =
+ hyper::Response::builder().status(301).header("Location", "two").body(Body::empty()).unwrap();
+ Ok(response)
+ }
+ _ => {
+ let response = hyper::Response::builder()
+ .status(301)
+ .header("Location", "/myblock/one")
+ .body(Body::empty())
+ .unwrap();
+ Ok(response)
+ }
+ }
+ } else {
+ panic!("expect upgrade to WS");
+ }
+}
diff --git a/ws-client/Cargo.toml b/ws-client/Cargo.toml
index 11074d919c..b57b895e0d 100644
--- a/ws-client/Cargo.toml
+++ b/ws-client/Cargo.toml
@@ -18,15 +18,16 @@ arrayvec = "0.7.1"
async-trait = "0.1"
fnv = "1"
futures = { version = "0.3.14", default-features = false, features = ["std"] }
+http = "0.2"
jsonrpsee-types = { path = "../types", version = "0.3.0" }
log = "0.4"
+pin-project = "1"
+rustls-native-certs = "0.5.0"
serde = "1"
serde_json = "1"
soketto = "0.7"
-pin-project = "1"
thiserror = "1"
-url = "2"
-rustls-native-certs = "0.5.0"
[dev-dependencies]
jsonrpsee-test-utils = { path = "../test-utils" }
+env_logger = "0.9"
diff --git a/ws-client/src/client.rs b/ws-client/src/client.rs
index 048c99a83d..1a1dfc8cd4 100644
--- a/ws-client/src/client.rs
+++ b/ws-client/src/client.rs
@@ -24,7 +24,7 @@
// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
-use crate::transport::{Receiver as WsReceiver, Sender as WsSender, Target, WsTransportClientBuilder};
+use crate::transport::{Receiver as WsReceiver, Sender as WsSender, WsHandshakeError, WsTransportClientBuilder};
use crate::types::{
traits::{Client, SubscriptionClient},
v2::{Id, Notification, NotificationSer, ParamsSer, RequestSer, Response, RpcError, SubscriptionResponse},
@@ -46,10 +46,13 @@ use futures::{
prelude::*,
sink::SinkExt,
};
+use http::uri::{InvalidUri, Uri};
use tokio::sync::Mutex;
use serde::de::DeserializeOwned;
-use std::{borrow::Cow, time::Duration};
+use std::{borrow::Cow, convert::TryInto, time::Duration};
+
+pub use soketto::handshake::client::Header;
/// Wrapper over a [`oneshot::Receiver`](futures::channel::oneshot::Receiver) that reads
/// the underlying channel once and then stores the result in String.
@@ -109,6 +112,7 @@ pub struct WsClientBuilder<'a> {
origin_header: Option>,
max_concurrent_requests: usize,
max_notifs_per_subscription: usize,
+ max_redirections: usize,
}
impl<'a> Default for WsClientBuilder<'a> {
@@ -121,6 +125,7 @@ impl<'a> Default for WsClientBuilder<'a> {
origin_header: None,
max_concurrent_requests: 256,
max_notifs_per_subscription: 1024,
+ max_redirections: 5,
}
}
}
@@ -151,8 +156,8 @@ impl<'a> WsClientBuilder<'a> {
}
/// Set origin header to pass during the handshake.
- pub fn origin_header(mut self, origin: &'a str) -> Self {
- self.origin_header = Some(Cow::Borrowed(origin));
+ pub fn origin_header(mut self, origin: Cow<'a, str>) -> Self {
+ self.origin_header = Some(origin);
self
}
@@ -176,18 +181,19 @@ impl<'a> WsClientBuilder<'a> {
self
}
+ /// Set the max number of redirections to perform until a connection is regarded as failed.
+ pub fn max_redirections(mut self, redirect: usize) -> Self {
+ self.max_redirections = redirect;
+ self
+ }
+
/// Build the client with specified URL to connect to.
- /// If the port number is missing from the URL, the default port number is used.
- ///
- ///
- /// `ws://host` - port 80 is used
- ///
- /// `wss://host` - port 443 is used
+ /// You must provide the port number in the URL.
///
/// ## Panics
///
/// Panics if being called outside of `tokio` runtime context.
- pub async fn build(self, url: &'a str) -> Result {
+ pub async fn build(self, uri: &'a str) -> Result {
let certificate_store = self.certificate_store;
let max_capacity_per_subscription = self.max_notifs_per_subscription;
let max_concurrent_requests = self.max_concurrent_requests;
@@ -195,12 +201,15 @@ impl<'a> WsClientBuilder<'a> {
let (to_back, from_front) = mpsc::channel(self.max_concurrent_requests);
let (err_tx, err_rx) = oneshot::channel();
+ let uri: Uri = uri.parse().map_err(|e: InvalidUri| Error::Transport(e.into()))?;
+
let builder = WsTransportClientBuilder {
certificate_store,
- target: Target::parse(url).map_err(|e| Error::Transport(e.into()))?,
+ target: uri.try_into().map_err(|e: WsHandshakeError| Error::Transport(e.into()))?,
timeout: self.connection_timeout,
origin_header: self.origin_header,
max_request_body_size: self.max_request_body_size,
+ max_redirections: self.max_redirections,
};
let (sender, receiver) = builder.build().await.map_err(|e| Error::Transport(e.into()))?;
diff --git a/ws-client/src/tests.rs b/ws-client/src/tests.rs
index 2b9f74cbee..b2b0261826 100644
--- a/ws-client/src/tests.rs
+++ b/ws-client/src/tests.rs
@@ -32,7 +32,7 @@ use crate::types::{
};
use crate::WsClientBuilder;
use jsonrpsee_test_utils::helpers::*;
-use jsonrpsee_test_utils::types::{Id, WebSocketTestServer};
+use jsonrpsee_test_utils::mocks::{Id, WebSocketTestServer};
use jsonrpsee_test_utils::TimeoutFutureExt;
use serde_json::Value as JsonValue;
@@ -263,3 +263,34 @@ fn assert_error_response(err: Error, exp: ErrorObject) {
e => panic!("Expected error: \"{}\", got: {:?}", err, e),
};
}
+
+#[tokio::test]
+async fn redirections() {
+ let _ = env_logger::try_init();
+ let expected = "abc 123";
+ let server = WebSocketTestServer::with_hardcoded_response(
+ "127.0.0.1:0".parse().unwrap(),
+ ok_response(expected.into(), Id::Num(0)),
+ )
+ .with_default_timeout()
+ .await
+ .unwrap();
+
+ let server_url = format!("ws://{}", server.local_addr());
+ let redirect_url = jsonrpsee_test_utils::mocks::ws_server_with_redirect(server_url);
+
+ // The client will first connect to a server that only performs re-directions and finally
+ // redirect to another server to complete the handshake.
+ let client = WsClientBuilder::default().build(&redirect_url).with_default_timeout().await;
+ // It's an ok client
+ let client = match client {
+ Ok(Ok(client)) => client,
+ Ok(Err(e)) => panic!("WsClient builder failed with: {:?}", e),
+ Err(e) => panic!("WsClient builder timed out with: {:?}", e),
+ };
+ // It's connected
+ assert!(client.is_connected());
+ // It works
+ let response = client.request::("anything", ParamsSer::NoParams).with_default_timeout().await.unwrap();
+ assert_eq!(response.unwrap(), String::from(expected));
+}
diff --git a/ws-client/src/transport.rs b/ws-client/src/transport.rs
index 7a00ca6c1a..e18f34229f 100644
--- a/ws-client/src/transport.rs
+++ b/ws-client/src/transport.rs
@@ -24,12 +24,21 @@
// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
+use crate::stream::EitherStream;
use arrayvec::ArrayVec;
use futures::io::{BufReader, BufWriter};
-use futures::prelude::*;
+use http::Uri;
use soketto::connection;
-use soketto::handshake::client::{Client as WsRawClient, Header, ServerResponse};
-use std::{borrow::Cow, io, net::SocketAddr, sync::Arc, time::Duration};
+use soketto::handshake::client::{Client as WsHandshakeClient, Header, ServerResponse};
+use std::convert::TryInto;
+use std::{
+ borrow::Cow,
+ convert::TryFrom,
+ io,
+ net::{SocketAddr, ToSocketAddrs},
+ sync::Arc,
+ time::Duration,
+};
use thiserror::Error;
use tokio::net::TcpStream;
use tokio_rustls::{
@@ -39,7 +48,7 @@ use tokio_rustls::{
TlsConnector,
};
-type TlsOrPlain = crate::stream::EitherStream>;
+type TlsOrPlain = EitherStream>;
/// Sending end of WebSocket transport.
#[derive(Debug)]
@@ -67,6 +76,8 @@ pub struct WsTransportClientBuilder<'a> {
pub origin_header: Option>,
/// Max payload size
pub max_request_body_size: u32,
+ /// Max number of redirections.
+ pub max_redirections: usize,
}
/// Stream mode, either plain TCP or TLS.
@@ -95,42 +106,42 @@ pub enum CertificateStore {
#[derive(Debug, Error)]
pub enum WsHandshakeError {
/// Failed to load system certs
- #[error("Failed to load system certs: {}", 0)]
+ #[error("Failed to load system certs: {0}")]
CertificateStore(io::Error),
/// Invalid URL.
- #[error("Invalid url: {}", 0)]
+ #[error("Invalid URL: {0}")]
Url(Cow<'static, str>),
/// Error when opening the TCP socket.
- #[error("Error when opening the TCP socket: {}", 0)]
+ #[error("Error when opening the TCP socket: {0}")]
Io(io::Error),
/// Error in the transport layer.
- #[error("Error in the WebSocket handshake: {}", 0)]
+ #[error("Error in the WebSocket handshake: {0}")]
Transport(#[source] soketto::handshake::Error),
/// Invalid DNS name error for TLS
- #[error("Invalid DNS name: {}", 0)]
+ #[error("Invalid DNS name: {0}")]
InvalidDnsName(#[source] InvalidDNSNameError),
- /// RawServer rejected our handshake.
- #[error("Connection rejected with status code: {}", status_code)]
+ /// Server rejected the handshake.
+ #[error("Connection rejected with status code: {status_code}")]
Rejected {
/// HTTP status code that the server returned.
status_code: u16,
},
/// Timeout while trying to connect.
- #[error("Connection timeout exceeded: {}", 0)]
+ #[error("Connection timeout exceeded: {0:?}")]
Timeout(Duration),
/// Failed to resolve IP addresses for this hostname.
- #[error("Failed to resolve IP addresses for this hostname: {}", 0)]
+ #[error("Failed to resolve IP addresses for this hostname: {0}")]
ResolutionFailed(io::Error),
/// Couldn't find any IP address for this hostname.
- #[error("No IP address found for this hostname: {}", 0)]
+ #[error("No IP address found for this hostname: {0}")]
NoAddressFound(String),
}
@@ -186,79 +197,143 @@ impl<'a> WsTransportClientBuilder<'a> {
Mode::Plain => None,
};
- let mut err = None;
- for sockaddr in &self.target.sockaddrs {
- match self.try_connect(*sockaddr, &connector).await {
- Ok(res) => return Ok(res),
- Err(e) => {
- log::debug!("Failed to connect to sockaddr: {:?} with err: {:?}", sockaddr, e);
- err = Some(Err(e));
- }
- }
- }
- // NOTE(niklasad1): this is most likely unreachable because [`Url::socket_addrs`] doesn't
- // return an empty `Vec` if no socket address was found for the host name.
- err.unwrap_or(Err(WsHandshakeError::NoAddressFound(self.target.host)))
+ self.try_connect(connector).await
}
async fn try_connect(
- &self,
- sockaddr: SocketAddr,
- tls_connector: &Option,
+ self,
+ mut tls_connector: Option,
) -> Result<(Sender, Receiver), WsHandshakeError> {
- // Try establish the TCP connection.
- let tcp_stream = {
- let socket = TcpStream::connect(sockaddr);
- let timeout = tokio::time::sleep(self.timeout);
- futures::pin_mut!(socket, timeout);
- match future::select(socket, timeout).await {
- future::Either::Left((socket, _)) => {
- let socket = socket?;
- if let Err(err) = socket.set_nodelay(true) {
- log::warn!("set nodelay failed: {:?}", err);
- }
- match tls_connector {
- None => TlsOrPlain::Plain(socket),
- Some(connector) => {
- let dns_name = DNSNameRef::try_from_ascii_str(&self.target.host)?;
- let tls_stream = connector.connect(dns_name, socket).await?;
- TlsOrPlain::Tls(tls_stream)
- }
- }
- }
- future::Either::Right((_, _)) => return Err(WsHandshakeError::Timeout(self.timeout)),
- }
- };
-
- log::debug!("Connecting to target: {:?}", self.target);
- let mut client = WsRawClient::new(
- BufReader::new(BufWriter::new(tcp_stream)),
- &self.target.host_header,
- &self.target.path_and_query,
- );
-
+ let mut target = self.target;
let mut headers: ArrayVec = ArrayVec::new();
+ let mut err = None;
if let Some(origin) = self.origin_header.as_ref() {
headers.push(Header { name: "Origin", value: origin.as_bytes() });
}
- client.set_headers(&headers);
+ for _ in 0..self.max_redirections {
+ log::debug!("Connecting to target: {:?}", target);
+
+ // The sockaddrs might get reused if the server replies with a relative URI.
+ let sockaddrs = std::mem::take(&mut target.sockaddrs);
+ for sockaddr in &sockaddrs {
+ let tcp_stream = match connect(*sockaddr, self.timeout, &target.host, &tls_connector).await {
+ Ok(stream) => stream,
+ Err(e) => {
+ log::debug!("Failed to connect to sockaddr: {:?}", sockaddr);
+ err = Some(Err(e));
+ continue;
+ }
+ };
+ let mut client = WsHandshakeClient::new(
+ BufReader::new(BufWriter::new(tcp_stream)),
+ &target.host_header,
+ &target.path_and_query,
+ );
+
+ client.set_headers(&headers);
+
+ // Perform the initial handshake.
+ match client.handshake().await {
+ Ok(ServerResponse::Accepted { .. }) => {
+ log::info!("Connection established to target: {:?}", target);
+ let mut builder = client.into_builder();
+ builder.set_max_message_size(self.max_request_body_size as usize);
+ let (sender, receiver) = builder.finish();
+ return Ok((Sender { inner: sender }, Receiver { inner: receiver }));
+ }
- // Perform the initial handshake.
- match client.handshake().await? {
- ServerResponse::Accepted { .. } => {}
- ServerResponse::Rejected { status_code } | ServerResponse::Redirect { status_code, .. } => {
- // TODO: HTTP redirects also lead here #339.
- return Err(WsHandshakeError::Rejected { status_code });
+ Ok(ServerResponse::Rejected { status_code }) => {
+ log::debug!("Connection rejected: {:?}", status_code);
+ err = Some(Err(WsHandshakeError::Rejected { status_code }));
+ }
+ Ok(ServerResponse::Redirect { status_code, location }) => {
+ log::debug!("Redirection: status_code: {}, location: {}", status_code, location);
+ match location.parse::() {
+ // redirection with absolute path => need to lookup.
+ Ok(uri) => {
+ // Absolute URI.
+ if uri.scheme().is_some() {
+ target = uri.try_into()?;
+ tls_connector = match target.mode {
+ Mode::Tls => {
+ let mut client_config = ClientConfig::default();
+ if let CertificateStore::Native = self.certificate_store {
+ client_config.root_store = rustls_native_certs::load_native_certs()
+ .map_err(|(_, e)| WsHandshakeError::CertificateStore(e))?;
+ }
+ Some(Arc::new(client_config).into())
+ }
+ Mode::Plain => None,
+ };
+ break;
+ }
+ // Relative URI.
+ else {
+ // Replace the entire path_and_query if `location` starts with `/` or `//`.
+ if location.starts_with('/') {
+ target.path_and_query = location;
+ } else {
+ match target.path_and_query.rfind('/') {
+ Some(offset) => {
+ target.path_and_query.replace_range(offset + 1.., &location)
+ }
+ None => {
+ err = Some(Err(WsHandshakeError::Url(
+ format!(
+ "path_and_query: {}; this is a bug it must contain `/` please open issue",
+ location
+ )
+ .into(),
+ )));
+ continue;
+ }
+ };
+ }
+ target.sockaddrs = sockaddrs;
+ break;
+ }
+ }
+ Err(e) => {
+ err = Some(Err(WsHandshakeError::Url(e.to_string().into())));
+ }
+ };
+ }
+ Err(e) => {
+ err = Some(Err(e.into()));
+ }
+ };
}
}
+ err.unwrap_or(Err(WsHandshakeError::NoAddressFound(target.host)))
+ }
+}
- // If the handshake succeeded, return.
- let mut builder = client.into_builder();
- builder.set_max_message_size(self.max_request_body_size as usize);
- let (sender, receiver) = builder.finish();
- Ok((Sender { inner: sender }, Receiver { inner: receiver }))
+async fn connect(
+ sockaddr: SocketAddr,
+ timeout_dur: Duration,
+ host: &str,
+ tls_connector: &Option,
+) -> Result>, WsHandshakeError> {
+ let socket = TcpStream::connect(sockaddr);
+ let timeout = tokio::time::sleep(timeout_dur);
+ tokio::select! {
+ socket = socket => {
+ let socket = socket?;
+ if let Err(err) = socket.set_nodelay(true) {
+ log::warn!("set nodelay failed: {:?}", err);
+ }
+ match tls_connector {
+ None => Ok(TlsOrPlain::Plain(socket)),
+ Some(connector) => {
+ let dns_name = DNSNameRef::try_from_ascii_str(host)?;
+ let tls_stream = connector.connect(dns_name, socket).await?;
+ Ok(TlsOrPlain::Tls(tls_stream))
+ }
+ }
+ }
+ _ = timeout => Err(WsHandshakeError::Timeout(timeout_dur))
}
}
@@ -301,34 +376,32 @@ pub struct Target {
path_and_query: String,
}
-impl Target {
- /// Parse an URL String to a WebSocket address.
- pub fn parse(url: impl AsRef) -> Result {
- let url =
- url::Url::parse(url.as_ref()).map_err(|e| WsHandshakeError::Url(format!("Invalid URL: {}", e).into()))?;
- let mode = match url.scheme() {
- "ws" => Mode::Plain,
- "wss" => Mode::Tls,
+impl TryFrom for Target {
+ type Error = WsHandshakeError;
+
+ fn try_from(uri: Uri) -> Result {
+ let mode = match uri.scheme_str() {
+ Some("ws") => Mode::Plain,
+ Some("wss") => Mode::Tls,
_ => return Err(WsHandshakeError::Url("URL scheme not supported, expects 'ws' or 'wss'".into())),
};
- let host =
- url.host_str().map(ToOwned::to_owned).ok_or_else(|| WsHandshakeError::Url("No host in URL".into()))?;
- let port = url.port_or_known_default().ok_or_else(|| WsHandshakeError::Url("No port number in URL".into()))?;
+ let host = uri.host().map(ToOwned::to_owned).ok_or_else(|| WsHandshakeError::Url("No host in URL".into()))?;
+ let port = uri
+ .port_u16()
+ .ok_or_else(|| WsHandshakeError::Url("No port number in URL (default port is not supported)".into()))?;
let host_header = format!("{}:{}", host, port);
- let mut path_and_query = url.path().to_owned();
- if let Some(query) = url.query() {
- path_and_query.push('?');
- path_and_query.push_str(query);
- }
- // NOTE: `Url::socket_addrs` is using the default port if it's missing (ws:// - 80, wss:// - 443)
- let sockaddrs = url.socket_addrs(|| None).map_err(WsHandshakeError::ResolutionFailed)?;
- Ok(Self { sockaddrs, host, host_header, mode, path_and_query })
+ let parts = uri.into_parts();
+ let path_and_query = parts.path_and_query.ok_or_else(|| WsHandshakeError::Url("No path in URL".into()))?;
+ let sockaddrs = host_header.to_socket_addrs().map_err(WsHandshakeError::ResolutionFailed)?;
+ Ok(Self { sockaddrs: sockaddrs.collect(), host, host_header, mode, path_and_query: path_and_query.to_string() })
}
}
#[cfg(test)]
mod tests {
- use super::{Mode, Target, WsHandshakeError};
+ use super::{Mode, Target, Uri, WsHandshakeError};
+ use http::uri::InvalidUri;
+ use std::convert::TryInto;
fn assert_ws_target(target: Target, host: &str, host_header: &str, mode: Mode, path_and_query: &str) {
assert_eq!(&target.host, host);
@@ -337,53 +410,51 @@ mod tests {
assert_eq!(&target.path_and_query, path_and_query);
}
+ fn parse_target(uri: &str) -> Result {
+ uri.parse::().map_err(|e: InvalidUri| WsHandshakeError::Url(e.to_string().into()))?.try_into()
+ }
+
#[test]
fn ws_works() {
- let target = Target::parse("ws://127.0.0.1:9933").unwrap();
+ let target = parse_target("ws://127.0.0.1:9933").unwrap();
assert_ws_target(target, "127.0.0.1", "127.0.0.1:9933", Mode::Plain, "/");
}
#[test]
fn wss_works() {
- let target = Target::parse("wss://kusama-rpc.polkadot.io:443").unwrap();
+ let target = parse_target("wss://kusama-rpc.polkadot.io:443").unwrap();
assert_ws_target(target, "kusama-rpc.polkadot.io", "kusama-rpc.polkadot.io:443", Mode::Tls, "/");
}
#[test]
fn faulty_url_scheme() {
- let err = Target::parse("http://kusama-rpc.polkadot.io:443").unwrap_err();
+ let err = parse_target("http://kusama-rpc.polkadot.io:443").unwrap_err();
assert!(matches!(err, WsHandshakeError::Url(_)));
}
#[test]
fn faulty_port() {
- let err = Target::parse("ws://127.0.0.1:-43").unwrap_err();
+ let err = parse_target("ws://127.0.0.1:-43").unwrap_err();
assert!(matches!(err, WsHandshakeError::Url(_)));
- let err = Target::parse("ws://127.0.0.1:99999").unwrap_err();
+ let err = parse_target("ws://127.0.0.1:99999").unwrap_err();
assert!(matches!(err, WsHandshakeError::Url(_)));
}
- #[test]
- fn default_port_works() {
- let target = Target::parse("ws://127.0.0.1").unwrap();
- assert_ws_target(target, "127.0.0.1", "127.0.0.1:80", Mode::Plain, "/");
- }
-
#[test]
fn url_with_path_works() {
- let target = Target::parse("wss://127.0.0.1/my-special-path").unwrap();
+ let target = parse_target("wss://127.0.0.1:443/my-special-path").unwrap();
assert_ws_target(target, "127.0.0.1", "127.0.0.1:443", Mode::Tls, "/my-special-path");
}
#[test]
fn url_with_query_works() {
- let target = Target::parse("wss://127.0.0.1/my?name1=value1&name2=value2").unwrap();
+ let target = parse_target("wss://127.0.0.1:443/my?name1=value1&name2=value2").unwrap();
assert_ws_target(target, "127.0.0.1", "127.0.0.1:443", Mode::Tls, "/my?name1=value1&name2=value2");
}
#[test]
fn url_with_fragment_is_ignored() {
- let target = Target::parse("wss://127.0.0.1/my.htm#ignore").unwrap();
+ let target = parse_target("wss://127.0.0.1:443/my.htm#ignore").unwrap();
assert_ws_target(target, "127.0.0.1", "127.0.0.1:443", Mode::Tls, "/my.htm");
}
}
diff --git a/ws-server/src/tests.rs b/ws-server/src/tests.rs
index d1b9cc7dbd..5739cd37ba 100644
--- a/ws-server/src/tests.rs
+++ b/ws-server/src/tests.rs
@@ -30,7 +30,7 @@ use crate::types::error::{CallError, Error};
use crate::{future::StopHandle, RpcModule, WsServerBuilder};
use anyhow::anyhow;
use jsonrpsee_test_utils::helpers::*;
-use jsonrpsee_test_utils::types::{Id, TestContext, WebSocketTestClient, WebSocketTestError};
+use jsonrpsee_test_utils::mocks::{Id, TestContext, WebSocketTestClient, WebSocketTestError};
use jsonrpsee_test_utils::TimeoutFutureExt;
use serde_json::Value as JsonValue;
use std::fmt;