diff --git a/examples/Cargo.toml b/examples/Cargo.toml index ba23a6d3b1..49ce8c0e0a 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -50,3 +50,7 @@ path = "proc_macro.rs" [[example]] name = "core_client" path = "core_client.rs" + +[[example]] +name = "cors_server" +path = "cors_server.rs" \ No newline at end of file diff --git a/examples/cors_server.rs b/examples/cors_server.rs new file mode 100644 index 0000000000..a87a477de6 --- /dev/null +++ b/examples/cors_server.rs @@ -0,0 +1,82 @@ +// Copyright 2019-2022 Parity Technologies (UK) Ltd. +// +// Permission is hereby granted, free of charge, to any +// person obtaining a copy of this software and associated +// documentation files (the "Software"), to deal in the +// Software without restriction, including without +// limitation the rights to use, copy, modify, merge, +// publish, distribute, sublicense, and/or sell copies of +// the Software, and to permit persons to whom the Software +// is furnished to do so, subject to the following +// conditions: +// +// The above copyright notice and this permission notice +// shall be included in all copies or substantial portions +// of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT +// SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR +// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. + +use std::net::SocketAddr; + +use jsonrpsee::http_server::{AccessControlBuilder, HttpServerBuilder, HttpServerHandle, RpcModule}; + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + tracing_subscriber::FmtSubscriber::builder() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init() + .expect("setting default subscriber failed"); + + // Start up a JSONPRC server that allows cross origin requests. + let (server_addr, _handle) = run_server().await?; + + // Print instructions for testing CORS from a browser. + println!("Run the following snippet in the developer console in any Website."); + println!( + r#" + fetch("http://{}", {{ + method: 'POST', + mode: 'cors', + headers: {{ 'Content-Type': 'application/json' }}, + body: JSON.stringify({{ + jsonrpc: '2.0', + method: 'say_hello', + id: 1 + }}) + }}).then(res => {{ + console.log("Response:", res); + return res.text() + }}).then(body => {{ + console.log("Response Body:", body) + }}); + "#, + server_addr + ); + + futures::future::pending().await +} + +async fn run_server() -> anyhow::Result<(SocketAddr, HttpServerHandle)> { + let acl = AccessControlBuilder::new().allow_all_headers().allow_all_origins().allow_all_hosts().build(); + + let server = HttpServerBuilder::default().set_access_control(acl).build("127.0.0.1:0".parse::()?)?; + + let mut module = RpcModule::new(()); + module.register_method("say_hello", |_, _| { + println!("say_hello method called!"); + Ok("Hello there!!") + })?; + + let addr = server.local_addr()?; + let server_handle = server.start(module)?; + + Ok((addr, server_handle)) +} diff --git a/http-server/src/access_control/cors.rs b/http-server/src/access_control/cors.rs index 3739077d0b..ba807743bf 100644 --- a/http-server/src/access_control/cors.rs +++ b/http-server/src/access_control/cors.rs @@ -31,6 +31,7 @@ use std::{fmt, ops}; use crate::access_control::hosts::{Host, Port}; use crate::access_control::matcher::{Matcher, Pattern}; +use jsonrpsee_core::Cow; use lazy_static::lazy_static; use unicase::Ascii; @@ -169,6 +170,16 @@ pub enum AccessControlAllowHeaders { Any, } +impl AccessControlAllowHeaders { + /// Return an appropriate value for the CORS header "Access-Control-Allow-Headers". + pub fn to_cors_header_value(&self) -> Cow<'_, str> { + match self { + AccessControlAllowHeaders::Any => "*".into(), + AccessControlAllowHeaders::Only(headers) => headers.join(", ").into(), + } + } +} + /// CORS response headers #[derive(Debug, Clone, PartialEq, Eq)] pub enum AllowCors { diff --git a/http-server/src/access_control/mod.rs b/http-server/src/access_control/mod.rs index 8c0fd0d76b..b17d616443 100644 --- a/http-server/src/access_control/mod.rs +++ b/http-server/src/access_control/mod.rs @@ -83,6 +83,11 @@ impl AccessControl { }); header == cors::AllowCors::Invalid && !self.continue_on_invalid_cors } + + /// Return the allowed headers we've set + pub(crate) fn allowed_headers(&self) -> &AccessControlAllowHeaders { + &self.allowed_headers + } } impl Default for AccessControl { diff --git a/http-server/src/response.rs b/http-server/src/response.rs index 8dbdb20ec6..9e3ef546bf 100644 --- a/http-server/src/response.rs +++ b/http-server/src/response.rs @@ -107,3 +107,12 @@ fn from_template>( pub fn ok_response(body: String) -> hyper::Response { from_template(hyper::StatusCode::OK, body, JSON) } + +/// Create a response for unsupported content type. +pub fn unsupported_content_type() -> hyper::Response { + from_template( + hyper::StatusCode::UNSUPPORTED_MEDIA_TYPE, + "Supplied content type is not allowed. Content-Type: application/json is required\n".to_owned(), + TEXT, + ) +} diff --git a/http-server/src/server.rs b/http-server/src/server.rs index fbbc86e779..c7df42b8ef 100644 --- a/http-server/src/server.rs +++ b/http-server/src/server.rs @@ -30,14 +30,16 @@ use std::net::{SocketAddr, TcpListener, ToSocketAddrs}; use std::pin::Pin; use std::task::{Context, Poll}; +use crate::response::{internal_error, malformed}; use crate::{response, AccessControl}; use futures_channel::mpsc; use futures_util::{future::join_all, stream::StreamExt, FutureExt}; +use hyper::header::{HeaderMap, HeaderValue}; use hyper::server::{conn::AddrIncoming, Builder as HyperBuilder}; use hyper::service::{make_service_fn, service_fn}; -use hyper::Error as HyperError; +use hyper::{Error as HyperError, Method}; use jsonrpsee_core::error::{Error, GenericTransportError}; -use jsonrpsee_core::http_helpers::read_body; +use jsonrpsee_core::http_helpers::{self, read_body}; use jsonrpsee_core::id_providers::NoopIdProvider; use jsonrpsee_core::middleware::Middleware; use jsonrpsee_core::server::helpers::{collect_batch_response, prepare_error, MethodSink}; @@ -305,113 +307,53 @@ impl Server { return Ok::<_, HyperError>(e); } - if let Err(e) = content_type_is_valid(&request) { - return Ok::<_, HyperError>(e); - } - - let (parts, body) = request.into_parts(); - - let (body, mut is_single) = match read_body(&parts.headers, body, max_request_body_size).await { - Ok(r) => r, - Err(GenericTransportError::TooLarge) => return Ok::<_, HyperError>(response::too_large()), - Err(GenericTransportError::Malformed) => return Ok::<_, HyperError>(response::malformed()), - Err(GenericTransportError::Inner(e)) => { - tracing::error!("Internal error reading request body: {}", e); - return Ok::<_, HyperError>(response::internal_error()); + // Only `POST` and `OPTIONS` methods are allowed. + match *request.method() { + // An OPTIONS request is a CORS preflight request. We've done our access check + // above so we just need to tell the browser that the request is OK. + Method::OPTIONS => { + let origin = match http_helpers::read_header_value(request.headers(), "origin") { + Some(origin) => origin, + None => return Ok(malformed()), + }; + let allowed_headers = access_control.allowed_headers().to_cors_header_value(); + let allowed_header_bytes = allowed_headers.as_bytes(); + + let res = hyper::Response::builder() + .header("access-control-allow-origin", origin) + .header("access-control-allow-methods", "POST") + .header("access-control-allow-headers", allowed_header_bytes) + .body(hyper::Body::empty()) + .unwrap_or_else(|e| { + tracing::error!("Error forming preflight response: {}", e); + internal_error() + }); + + Ok(res) } - }; - - let request_start = middleware.on_request(); - - // NOTE(niklasad1): it's a channel because it's needed for batch requests. - let (tx, mut rx) = mpsc::unbounded::(); - let sink = MethodSink::new_with_limit(tx, max_request_body_size); - - type Notif<'a> = Notification<'a, Option<&'a RawValue>>; - - // Single request or notification - if is_single { - if let Ok(req) = serde_json::from_slice::(&body) { - middleware.on_call(req.method.as_ref()); - - // NOTE: we don't need to track connection id on HTTP, so using hardcoded 0 here. - match methods.execute_with_resources(&sink, req, 0, &resources, &NoopIdProvider) { - Ok((name, MethodResult::Sync(success))) => { - middleware.on_result(name, success, request_start); - } - Ok((name, MethodResult::Async(fut))) => { - let success = fut.await; - - middleware.on_result(name, success, request_start); - } - Err(name) => { - middleware.on_result(name.as_ref(), false, request_start); - } + // The actual request. If it's a CORS request we need to remember to add + // the access-control-allow-origin header (despite preflight) to allow it + // to be read in a browser. + Method::POST if content_type_is_json(&request) => { + let origin = return_origin_if_different_from_host(request.headers()).cloned(); + let mut res = process_validated_request( + request, + middleware, + methods, + resources, + max_request_body_size, + ) + .await?; + + if let Some(origin) = origin { + res.headers_mut().insert("access-control-allow-origin", origin); } - } else if let Ok(_req) = serde_json::from_slice::(&body) { - return Ok::<_, HyperError>(response::ok_response("".into())); - } else { - let (id, code) = prepare_error(&body); - sink.send_error(id, code.into()); + Ok(res) } - - // Batch of requests or notifications - } else if let Ok(batch) = serde_json::from_slice::>(&body) { - if !batch.is_empty() { - let middleware = &middleware; - - join_all(batch.into_iter().filter_map( - move |req| match methods.execute_with_resources( - &sink, - req, - 0, - &resources, - &NoopIdProvider, - ) { - Ok((name, MethodResult::Sync(success))) => { - middleware.on_result(name, success, request_start); - None - } - Ok((name, MethodResult::Async(fut))) => Some(async move { - let success = fut.await; - middleware.on_result(name, success, request_start); - }), - Err(name) => { - middleware.on_result(name.as_ref(), false, request_start); - None - } - }, - )) - .await; - } else { - // "If the batch rpc call itself fails to be recognized as an valid JSON or as an - // Array with at least one value, the response from the Server MUST be a single - // Response object." – The Spec. - is_single = true; - sink.send_error(Id::Null, ErrorCode::InvalidRequest.into()); - } - } else if let Ok(_batch) = serde_json::from_slice::>(&body) { - return Ok::<_, HyperError>(response::ok_response("".into())); - } else { - // "If the batch rpc call itself fails to be recognized as an valid JSON or as an - // Array with at least one value, the response from the Server MUST be a single - // Response object." – The Spec. - is_single = true; - let (id, code) = prepare_error(&body); - sink.send_error(id, code.into()); + // Error scenarios: + Method::POST => Ok(response::unsupported_content_type()), + _ => Ok(response::method_not_allowed()), } - - // Closes the receiving half of a channel without dropping it. This prevents any further - // messages from being sent on the channel. - rx.close(); - let response = if is_single { - rx.next().await.expect("Sender is still alive managed by us above; qed") - } else { - collect_batch_response(rx).await - }; - tracing::debug!("[service_fn] sending back: {:?}", &response[..cmp::min(response.len(), 1024)]); - middleware.on_response(request_start); - Ok::<_, HyperError>(response::ok_response(response)) } })) } @@ -431,6 +373,20 @@ impl Server { } } +// Checks the origin and host headers. If they both exist, return the origin if it does not match the host. +// If one of them doesn't exist (origin most probably), or they are identical, return None. +fn return_origin_if_different_from_host(headers: &HeaderMap) -> Option<&HeaderValue> { + if let (Some(origin), Some(host)) = (headers.get("origin"), headers.get("host")) { + if origin != host { + Some(origin) + } else { + None + } + } else { + None + } +} + // Checks to that access control of the received request is the same as configured. fn access_control_is_valid( access_control: &AccessControl, @@ -449,11 +405,8 @@ fn access_control_is_valid( } /// Checks that content type of received request is valid for JSON-RPC. -fn content_type_is_valid(request: &hyper::Request) -> Result<(), hyper::Response> { - match *request.method() { - hyper::Method::POST if is_json(request.headers().get("content-type")) => Ok(()), - _ => Err(response::method_not_allowed()), - } +fn content_type_is_json(request: &hyper::Request) -> bool { + is_json(request.headers().get("content-type")) } /// Returns true if the `content_type` header indicates a valid JSON message. @@ -469,3 +422,110 @@ fn is_json(content_type: Option<&hyper::header::HeaderValue>) -> bool { _ => false, } } + +/// Process a verified request, it implies a POST request with content type JSON. +async fn process_validated_request( + request: hyper::Request, + middleware: impl Middleware, + methods: Methods, + resources: Resources, + max_request_body_size: u32, +) -> Result, HyperError> { + let (parts, body) = request.into_parts(); + + let (body, mut is_single) = match read_body(&parts.headers, body, max_request_body_size).await { + Ok(r) => r, + Err(GenericTransportError::TooLarge) => return Ok(response::too_large()), + Err(GenericTransportError::Malformed) => return Ok(response::malformed()), + Err(GenericTransportError::Inner(e)) => { + tracing::error!("Internal error reading request body: {}", e); + return Ok(response::internal_error()); + } + }; + + let request_start = middleware.on_request(); + + // NOTE(niklasad1): it's a channel because it's needed for batch requests. + let (tx, mut rx) = mpsc::unbounded::(); + let sink = MethodSink::new_with_limit(tx, max_request_body_size); + + type Notif<'a> = Notification<'a, Option<&'a RawValue>>; + + // Single request or notification + if is_single { + if let Ok(req) = serde_json::from_slice::(&body) { + middleware.on_call(req.method.as_ref()); + + // NOTE: we don't need to track connection id on HTTP, so using hardcoded 0 here. + match methods.execute_with_resources(&sink, req, 0, &resources, &NoopIdProvider) { + Ok((name, MethodResult::Sync(success))) => { + middleware.on_result(name, success, request_start); + } + Ok((name, MethodResult::Async(fut))) => { + let success = fut.await; + + middleware.on_result(name, success, request_start); + } + Err(name) => { + middleware.on_result(name.as_ref(), false, request_start); + } + } + } else if let Ok(_req) = serde_json::from_slice::(&body) { + return Ok::<_, HyperError>(response::ok_response("".into())); + } else { + let (id, code) = prepare_error(&body); + sink.send_error(id, code.into()); + } + + // Batch of requests or notifications + } else if let Ok(batch) = serde_json::from_slice::>(&body) { + if !batch.is_empty() { + let middleware = &middleware; + + join_all(batch.into_iter().filter_map(move |req| { + match methods.execute_with_resources(&sink, req, 0, &resources, &NoopIdProvider) { + Ok((name, MethodResult::Sync(success))) => { + middleware.on_result(name, success, request_start); + None + } + Ok((name, MethodResult::Async(fut))) => Some(async move { + let success = fut.await; + middleware.on_result(name, success, request_start); + }), + Err(name) => { + middleware.on_result(name.as_ref(), false, request_start); + None + } + } + })) + .await; + } else { + // "If the batch rpc call itself fails to be recognized as an valid JSON or as an + // Array with at least one value, the response from the Server MUST be a single + // Response object." – The Spec. + is_single = true; + sink.send_error(Id::Null, ErrorCode::InvalidRequest.into()); + } + } else if let Ok(_batch) = serde_json::from_slice::>(&body) { + return Ok(response::ok_response("".into())); + } else { + // "If the batch rpc call itself fails to be recognized as an valid JSON or as an + // Array with at least one value, the response from the Server MUST be a single + // Response object." – The Spec. + is_single = true; + let (id, code) = prepare_error(&body); + sink.send_error(id, code.into()); + } + + // Closes the receiving half of a channel without dropping it. This prevents any further + // messages from being sent on the channel. + rx.close(); + let response = if is_single { + rx.next().await.expect("Sender is still alive managed by us above; qed") + } else { + collect_batch_response(rx).await + }; + tracing::debug!("[service_fn] sending back: {:?}", &response[..cmp::min(response.len(), 1024)]); + middleware.on_response(request_start); + Ok(response::ok_response(response)) +} diff --git a/tests/Cargo.toml b/tests/Cargo.toml index 931a582e91..5f1de14229 100644 --- a/tests/Cargo.toml +++ b/tests/Cargo.toml @@ -16,3 +16,4 @@ tokio = { version = "1.8", features = ["full"] } tracing = "0.1" serde = "1" serde_json = "1" +hyper = { version = "0.14", features = ["http1", "client"] } diff --git a/tests/tests/helpers.rs b/tests/tests/helpers.rs index a191f1f3ea..b8551cc68c 100644 --- a/tests/tests/helpers.rs +++ b/tests/tests/helpers.rs @@ -28,7 +28,7 @@ use std::net::SocketAddr; use std::time::Duration; use jsonrpsee::core::Error; -use jsonrpsee::http_server::{HttpServerBuilder, HttpServerHandle}; +use jsonrpsee::http_server::{AccessControl, HttpServerBuilder, HttpServerHandle}; use jsonrpsee::ws_server::{WsServerBuilder, WsServerHandle}; use jsonrpsee::RpcModule; @@ -117,7 +117,11 @@ pub async fn websocket_server() -> SocketAddr { } pub async fn http_server() -> (SocketAddr, HttpServerHandle) { - let server = HttpServerBuilder::default().build("127.0.0.1:0").unwrap(); + http_server_with_access_control(AccessControl::default()).await +} + +pub async fn http_server_with_access_control(acl: AccessControl) -> (SocketAddr, HttpServerHandle) { + let server = HttpServerBuilder::default().set_access_control(acl).build("127.0.0.1:0").unwrap(); let mut module = RpcModule::new(()); let addr = server.local_addr().unwrap(); module.register_method("say_hello", |_, _| Ok("hello")).unwrap(); diff --git a/tests/tests/integration_tests.rs b/tests/tests/integration_tests.rs index 874e504c12..5adc1cd42b 100644 --- a/tests/tests/integration_tests.rs +++ b/tests/tests/integration_tests.rs @@ -30,7 +30,7 @@ use std::sync::Arc; use std::time::Duration; -use helpers::{http_server, websocket_server, websocket_server_with_subscription}; +use helpers::{http_server, http_server_with_access_control, websocket_server, websocket_server_with_subscription}; use jsonrpsee::core::client::{ClientT, Subscription, SubscriptionClientT}; use jsonrpsee::core::error::SubscriptionClosedReason; use jsonrpsee::core::{Error, JsonValue}; @@ -373,3 +373,140 @@ async fn ws_batch_works() { let responses: Vec = client.batch_request(batch).await.unwrap(); assert_eq!(responses, vec!["hello".to_string(), "hello".to_string()]); } + +#[tokio::test] +async fn http_unsupported_methods_dont_work() { + use hyper::{Body, Client, Method, Request}; + + let (server_addr, _handle) = http_server().await; + + let http_client = Client::new(); + let uri = format!("http://{}", server_addr); + + let req_is_client_error = |method| async { + let req = Request::builder() + .method(method) + .uri(&uri) + .header("content-type", "application/json") + .body(Body::from(r#"{ "jsonrpc": "2.0", method: "say_hello", "id": 1 }"#)) + .expect("request builder"); + + let res = http_client.request(req).await.unwrap(); + res.status().is_client_error() + }; + + for verb in [Method::GET, Method::PUT, Method::PATCH, Method::DELETE] { + assert!(req_is_client_error(verb).await); + } + for verb in [Method::POST] { + assert!(!req_is_client_error(verb).await); + } +} + +#[tokio::test] +async fn http_correct_content_type_required() { + use hyper::{Body, Client, Method, Request}; + + let (server_addr, _handle) = http_server().await; + + let http_client = Client::new(); + let uri = format!("http://{}", server_addr); + + // We don't set content-type at all + let req = Request::builder() + .method(Method::POST) + .uri(&uri) + .body(Body::from(r#"{ "jsonrpc": "2.0", method: "say_hello", "id": 1 }"#)) + .expect("request builder"); + + let res = http_client.request(req).await.unwrap(); + assert!(res.status().is_client_error()); + + // We use the wrong content-type + let req = Request::builder() + .method(Method::POST) + .uri(&uri) + .header("content-type", "application/text") + .body(Body::from(r#"{ "jsonrpc": "2.0", method: "say_hello", "id": 1 }"#)) + .expect("request builder"); + + let res = http_client.request(req).await.unwrap(); + assert!(res.status().is_client_error()); + + // We use the correct content-type + let req = Request::builder() + .method(Method::POST) + .uri(&uri) + .header("content-type", "application/json") + .body(Body::from(r#"{ "jsonrpc": "2.0", method: "say_hello", "id": 1 }"#)) + .expect("request builder"); + + let res = http_client.request(req).await.unwrap(); + assert!(res.status().is_success()); +} + +#[tokio::test] +async fn http_cors_preflight_works() { + use hyper::{Body, Client, Method, Request}; + use jsonrpsee::http_server::AccessControlBuilder; + + let acl = AccessControlBuilder::new().set_allowed_origins(vec!["https://foo.com"]).unwrap().build(); + let (server_addr, _handle) = http_server_with_access_control(acl).await; + + let http_client = Client::new(); + let uri = format!("http://{}", server_addr); + + // First, make a preflight request. + // See https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS#preflighted_requests for examples. + // See https://fetch.spec.whatwg.org/#http-cors-protocol for the spec. + let preflight_req = Request::builder() + .method(Method::OPTIONS) + .uri(&uri) + .header("host", "bar.com") // <- host that request is being sent _to_ + .header("origin", "https://foo.com") // <- where request is being sent _from_ + .header("access-control-request-method", "POST") + .header("access-control-request-headers", "content-type") + .body(Body::empty()) + .expect("preflight request builder"); + + let has = |v: &[String], s| v.iter().any(|v| v == s); + + let preflight_res = http_client.request(preflight_req).await.unwrap(); + let preflight_headers = preflight_res.headers(); + + let allow_origins = comma_separated_header_values(preflight_headers, "access-control-allow-origin"); + let allow_methods = comma_separated_header_values(preflight_headers, "access-control-allow-methods"); + let allow_headers = comma_separated_header_values(preflight_headers, "access-control-allow-headers"); + + // We expect the preflight response to tell us that our origin, methods and headers are all OK to use. + // If they aren't, the browser will not make the actual request. Note that if these `access-control-*` + // headers aren't return, the default is that the origin/method/headers are not allowed, I think. + assert!(preflight_res.status().is_success()); + assert!(has(&allow_origins, "https://foo.com") || has(&allow_origins, "*")); + assert!(has(&allow_methods, "post") || has(&allow_methods, "*")); + assert!(has(&allow_headers, "content-type") || has(&allow_headers, "*")); + + // Assuming that that was successful, we now make the actual request. No CORS headers are needed here + // as the browser checked their validity in the preflight request. + let req = Request::builder() + .method(Method::POST) + .uri(&uri) + .header("host", "bar.com") + .header("origin", "https://foo.com") + .header("content-type", "application/json") + .body(Body::from(r#"{ "jsonrpc": "2.0", method: "say_hello", "id": 1 }"#)) + .expect("actual request builder"); + + let res = http_client.request(req).await.unwrap(); + assert!(res.status().is_success()); + assert!(has(&allow_origins, "https://foo.com") || has(&allow_origins, "*")); +} + +fn comma_separated_header_values(headers: &hyper::HeaderMap, header: &str) -> Vec { + headers + .get_all(header) + .into_iter() + .flat_map(|value| value.to_str().unwrap().split(',').map(|val| val.trim())) + .map(|header| header.to_ascii_lowercase()) + .collect() +}