Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cross-origin protection #375

Merged
merged 10 commits into from
Jun 18, 2021
Merged
2 changes: 1 addition & 1 deletion test-utils/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,6 @@ hyper = { version = "0.14", features = ["full"] }
log = "0.4"
serde = { version = "1", default-features = false, features = ["derive"] }
serde_json = "1"
soketto = "0.5"
soketto = "0.6"
tokio = { version = "1", features = ["net", "rt-multi-thread", "macros", "time"] }
tokio-util = { version = "0.6", features = ["compat"] }
6 changes: 3 additions & 3 deletions test-utils/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -199,12 +199,12 @@ async fn server_backend(listener: tokio::net::TcpListener, mut exit: Receiver<()
async fn connection_task(socket: tokio::net::TcpStream, mode: ServerMode, mut exit: Receiver<()>) {
let mut server = Server::new(socket.compat());

let websocket_key = match server.receive_request().await {
Ok(req) => req.into_key(),
let key = match server.receive_request().await {
Ok(req) => req.key(),
Err(_) => return,
};

let accept = server.send_response(&Response::Accept { key: &websocket_key, protocol: None }).await;
let accept = server.send_response(&Response::Accept { key, protocol: None }).await;

if accept.is_err() {
return;
Expand Down
2 changes: 1 addition & 1 deletion types/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,5 @@ log = { version = "0.4", default-features = false }
serde = { version = "1", default-features = false, features = ["derive"] }
serde_json = { version = "1", default-features = false, features = ["alloc", "raw_value", "std"] }
thiserror = "1.0"
soketto = "0.5"
soketto = "0.6"
hyper = "0.14"
3 changes: 3 additions & 0 deletions types/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ pub enum Error {
/// Configured max number of request slots exceeded.
#[error("Configured max number of request slots exceeded")]
MaxSlotsExceeded,
/// List passed into `set_allowed_origins` was empty
#[error("Must set at least one allowed origin")]
EmptyAllowedOrigins,
/// Custom error.
#[error("Custom error: {0}")]
Custom(String),
Expand Down
2 changes: 1 addition & 1 deletion ws-client/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ jsonrpsee-types = { path = "../types", version = "0.2.0" }
log = "0.4"
serde = "1"
serde_json = "1"
soketto = "0.5"
soketto = "0.6"
pin-project = "1"
thiserror = "1"
url = "2"
Expand Down
2 changes: 1 addition & 1 deletion ws-server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ log = "0.4"
rustc-hash = "1.1.0"
serde = { version = "1", default-features = false, features = ["derive"] }
serde_json = { version = "1", features = ["raw_value"] }
soketto = "0.5"
soketto = "0.6"
tokio = { version = "1", features = ["net", "rt-multi-thread", "macros"] }
tokio-stream = { version = "0.1.1", features = ["net"] }
tokio-util = { version = "0.6", features = ["compat"] }
Expand Down
85 changes: 76 additions & 9 deletions ws-server/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ impl Server {
pub async fn start(self) {
let mut incoming = TcpListenerStream::new(self.listener);
let methods = Arc::new(self.methods);
let cfg = self.cfg;
// let cfg = self.cfg;
maciejhirsz marked this conversation as resolved.
Show resolved Hide resolved
let mut id = 0;

while let Some(socket) = incoming.next().await {
Expand All @@ -88,7 +88,7 @@ impl Server {
}
let methods = methods.clone();

tokio::spawn(background_task(socket, id, methods, cfg));
tokio::spawn(background_task(socket, id, methods, self.cfg.clone()));

id += 1;
}
Expand All @@ -105,14 +105,24 @@ async fn background_task(
// For each incoming background_task we perform a handshake.
let mut server = SokettoServer::new(BufReader::new(BufWriter::new(socket.compat())));

let websocket_key = {
let key = {
let req = server.receive_request().await?;
req.into_key()

cfg.cors.verify_origin(req.headers().origin).map(|_| req.key())
};

// Here we accept the client unconditionally.
let accept = Response::Accept { key: &websocket_key, protocol: None };
server.send_response(&accept).await?;
match key {
Ok(key) => {
let accept = Response::Accept { key, protocol: None };
server.send_response(&accept).await?;
}
Err(error) => {
let reject = Response::Reject { status_code: 403 };
server.send_response(&reject).await?;

return Err(error);
}
}

// And we can finally transition to a websocket background_task.
let (mut sender, mut receiver) = server.into_builder().finish();
Expand Down Expand Up @@ -179,18 +189,40 @@ async fn background_task(
}
}

#[derive(Debug, Clone)]
enum Cors {
maciejhirsz marked this conversation as resolved.
Show resolved Hide resolved
AllowAny,
AllowList(Arc<[String]>),
}

impl Cors {
fn verify_origin(&self, origin: Option<&[u8]>) -> Result<(), Error> {
if let (Cors::AllowList(list), Some(origin)) = (self, origin) {
if !list.iter().any(|o| o.as_bytes() == origin) {
let error = format!("Origin denied: {}", String::from_utf8_lossy(origin));
log::warn!("{}", error);
return Err(Error::Request(error));
}
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want/need to handle "*" patterns as well in the allowed origins?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, I think so

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes we do.

Copy link
Contributor Author

@maciejhirsz maciejhirsz Jun 15, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are 3 possible states here:

  1. There is no Origin header, which equates to the request being done either on the domain (no cross-origin shenanigans), or the request isn't coming from a browser at all (in which case the client can spoof the header to whatever it wants). We always allow these.
  2. The header has a protocol://hostname value, if so we check it against the list (unless we allow all origins).
  3. The header has a null value, which can be explicitly allowed by the list, but is generally advised against.

There is no * origin.

Copy link
Collaborator

@jsdw jsdw Jun 15, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you explained this in a previous comment which makes sense to me now :)

My initial thinking was that CORS responses can contain "*" in allowed origins, so perhaps we need to handle something like this.

I did a little reading about WebSockets (since I didn't really know anything about the WebSocket upgrade protocol) and, indeed, there is no such thing as CORS really when talking about WS connections by the sounds of it, and so the origin checking here can essentially take whatever form we like.

Given that, I have no problem with not handling * in an allowed origin; it could potentially be a future enhancement (so that you can say configure this to allow eg any connections from "https://*.mydomain.com") but I guess there's no need for that sort of thing in the first cut!


Ok(())
}
}

/// JSON-RPC Websocket server settings.
#[derive(Debug, Clone, Copy)]
#[derive(Debug, Clone)]
struct Settings {
/// Maximum size in bytes of a request.
max_request_body_size: u32,
/// Maximum number of incoming connections allowed.
max_connections: u64,
/// Cross-origin policy by which to accept or deny incoming requests.
cors: Cors,
}

impl Default for Settings {
fn default() -> Self {
Self { max_request_body_size: TEN_MB_SIZE_BYTES, max_connections: MAX_CONNECTIONS }
Self { max_request_body_size: TEN_MB_SIZE_BYTES, max_connections: MAX_CONNECTIONS, cors: Cors::AllowAny }
}
}

Expand All @@ -213,6 +245,41 @@ impl Builder {
self
}

/// Set a list of allowed origins. During the handshake, the `Origin` header will be
/// checked against the list, connections without a matching origin will be denied.
/// Values should include protocol.
///
/// ```rust
/// # let mut builder = jsonrpsee_ws_server::WsServerBuilder::default();
/// builder.set_allowed_origins(vec!["https://example.com"]);
/// ```
///
/// By default allows any `Origin`.
///
/// Will return an error if `list` is empty. Use [`allow_all_origins`](Builder::allow_all_origins) to restore the default.
pub fn set_allowed_origins<Origin, List>(mut self, list: List) -> Result<Self, Error>
where
List: IntoIterator<Item = Origin>,
Origin: Into<String>,
{
let list: Arc<_> = list.into_iter().map(Into::into).collect();

if list.len() == 0 {
return Err(Error::EmptyAllowedOrigins);
}

self.settings.cors = Cors::AllowList(list);

Ok(self)
}

/// Restores the default behavior of allowing connections with `Origin` header
/// containing any value. This will undo any list set by [`set_allowed_origins`](Builder::set_allowed_origins).
pub fn allow_all_origins(mut self) -> Self {
self.settings.cors = Cors::AllowAny;
self
}

/// Finalize the configuration of the server. Consumes the [`Builder`].
pub async fn build(self, addr: impl ToSocketAddrs) -> Result<Server, Error> {
let listener = TcpListener::bind(addr).await?;
Expand Down