Skip to content

Commit

Permalink
feat: cors support
Browse files Browse the repository at this point in the history
  • Loading branch information
joseluisq committed Jun 1, 2021
1 parent 991d4b8 commit af9a329
Show file tree
Hide file tree
Showing 4 changed files with 293 additions and 5 deletions.
267 changes: 267 additions & 0 deletions src/cors.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,267 @@
// CORS handler for incoming requests.
// -> Part of the file is borrowed from https://github.com/seanmonstar/warp/blob/master/src/filters/cors.rs

use headers::{HeaderName, HeaderValue, Origin};
use http::header;
use std::{collections::HashSet, convert::TryFrom, sync::Arc};

/// It defines CORS instance.
#[derive(Clone, Debug)]
pub struct Cors {
allowed_headers: HashSet<HeaderName>,
max_age: Option<u64>,
allowed_methods: HashSet<http::Method>,
origins_str: String,
origins: Option<HashSet<HeaderValue>>,
}

/// It builds a new CORS instance.
pub fn new(origins_str: String) -> Option<Arc<Configured>> {
let cors = Cors::new(origins_str.clone());
let cors = if origins_str.is_empty() {
None
} else if origins_str == "*" {
Some(cors.allow_any_origin().allow_methods(vec!["GET", "HEAD"]))
} else {
let hosts = origins_str
.split(',')
.map(|s| s.trim().as_ref())
.collect::<Vec<_>>();

if hosts.is_empty() {
None
} else {
Some(cors.allow_origins(hosts).allow_methods(vec!["GET", "HEAD"]))
}
};
if cors.is_some() {
tracing::info!(
"enabled=true, allow_methods=[GET, HEAD], allow_origins={}",
origins_str
);
}
Cors::build(cors)
}

impl Cors {
/// Creates a new Cors instance.
pub fn new(origins_str: String) -> Self {
Self {
origins: None,
allowed_headers: HashSet::new(),
allowed_methods: HashSet::new(),
max_age: None,
origins_str,
}
}

/// Adds multiple methods to the existing list of allowed request methods.
///
/// # Panics
///
/// Panics if the provided argument is not a valid `http::Method`.
pub fn allow_methods<I>(mut self, methods: I) -> Self
where
I: IntoIterator,
http::Method: TryFrom<I::Item>,
{
let iter = methods.into_iter().map(|m| match TryFrom::try_from(m) {
Ok(m) => m,
Err(_) => panic!("cors: illegal method"),
});
self.allowed_methods.extend(iter);
self
}

/// Sets that *any* `Origin` header is allowed.
///
/// # Warning
///
/// This can allow websites you didn't intend to access this resource,
/// it is usually better to set an explicit list.
pub fn allow_any_origin(mut self) -> Self {
self.origins = None;
self
}

/// Add multiple origins to the existing list of allowed `Origin`s.
///
/// # Panics
///
/// Panics if the provided argument is not a valid `Origin`.
pub fn allow_origins<I>(mut self, origins: I) -> Self
where
I: IntoIterator,
I::Item: IntoOrigin,
{
let iter = origins
.into_iter()
.map(IntoOrigin::into_origin)
.map(|origin| {
origin
.to_string()
.parse()
.expect("cors: Origin is always a valid HeaderValue")
});

self.origins.get_or_insert_with(HashSet::new).extend(iter);
self
}

/// Sets the `Access-Control-Max-Age` header.
/// TODO: we could enable this in the future.
///
/// # Example
///
/// ```
/// let cors = cors::new("*")
/// .max_age(30) // 30u32 seconds
/// .max_age(Duration::from_secs(30)); // or a Duration
/// ```
pub fn max_age(mut self, seconds: impl Seconds) -> Self {
self.max_age = Some(seconds.seconds());
self
}

/// Builds the `Cors` wrapper from the configured settings.
pub fn build(cors: Option<Cors>) -> Option<Arc<Configured>> {
cors.as_ref()?;
let cors = cors?;
Some(Arc::new(Configured { cors }))
}
}

impl Default for Cors {
fn default() -> Self {
Self::new("*".to_string())
}
}

#[derive(Clone, Debug)]
pub struct Configured {
cors: Cors,
}

#[derive(Debug)]
pub enum Validated {
Preflight(HeaderValue),
Simple(HeaderValue),
NotCors,
}

#[derive(Debug)]
pub enum Forbidden {
Origin,
Method,
Header,
}

impl Default for Forbidden {
fn default() -> Self {
Self::Origin
}
}

impl Configured {
pub fn check_request(
&self,
method: &http::Method,
headers: &http::HeaderMap,
) -> Result<Validated, Forbidden> {
match (headers.get(header::ORIGIN), method) {
(Some(origin), &http::Method::OPTIONS) => {
// OPTIONS requests are preflight CORS requests...

if !self.is_origin_allowed(origin) {
return Err(Forbidden::Origin);
}

if let Some(req_method) = headers.get(header::ACCESS_CONTROL_REQUEST_METHOD) {
if !self.is_method_allowed(req_method) {
return Err(Forbidden::Method);
}
} else {
tracing::trace!(
"cors: preflight request missing access-control-request-method header"
);
return Err(Forbidden::Method);
}

if let Some(req_headers) = headers.get(header::ACCESS_CONTROL_REQUEST_HEADERS) {
let headers = req_headers.to_str().map_err(|_| Forbidden::Header)?;
for header in headers.split(',') {
if !self.is_header_allowed(header.trim()) {
return Err(Forbidden::Header);
}
}
}

Ok(Validated::Preflight(origin.clone()))
}
(Some(origin), _) => {
// Any other method, simply check for a valid origin...
tracing::trace!("cors origin header: {:?}", origin);

if self.is_origin_allowed(origin) {
Ok(Validated::Simple(origin.clone()))
} else {
Err(Forbidden::Origin)
}
}
(None, _) => {
// No `ORIGIN` header means this isn't CORS!
Ok(Validated::NotCors)
}
}
}

pub fn is_method_allowed(&self, header: &HeaderValue) -> bool {
http::Method::from_bytes(header.as_bytes())
.map(|method| self.cors.allowed_methods.contains(&method))
.unwrap_or(false)
}

pub fn is_header_allowed(&self, header: &str) -> bool {
HeaderName::from_bytes(header.as_bytes())
.map(|header| self.cors.allowed_headers.contains(&header))
.unwrap_or(false)
}

pub fn is_origin_allowed(&self, origin: &HeaderValue) -> bool {
if let Some(ref allowed) = self.cors.origins {
allowed.contains(origin)
} else {
true
}
}
}

pub trait Seconds {
fn seconds(self) -> u64;
}

impl Seconds for u32 {
fn seconds(self) -> u64 {
self.into()
}
}

impl Seconds for ::std::time::Duration {
fn seconds(self) -> u64 {
self.as_secs()
}
}

pub trait IntoOrigin {
fn into_origin(self) -> Origin;
}

impl<'a> IntoOrigin for &'a str {
fn into_origin(self) -> Origin {
let mut parts = self.splitn(2, "://");
let scheme = parts.next().expect("cors::into_origin: missing url scheme");
let rest = parts.next().expect("cors::into_origin: missing url scheme");

Origin::try_from_parts(scheme, rest, None).expect("cors::into_origin: invalid Origin")
}
}
22 changes: 20 additions & 2 deletions src/handler.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
use http::StatusCode;
use hyper::{Body, Request, Response};
use std::{future::Future, path::PathBuf};
use std::{future::Future, path::PathBuf, sync::Arc};

use crate::{compression, control_headers, static_files};
use crate::{compression, control_headers, cors, static_files};
use crate::{error_page, Error, Result};

// It defines options for a request handler.
pub struct RequestHandlerOpts {
pub root_dir: PathBuf,
pub compression: bool,
pub dir_listing: bool,
pub cors: Option<Arc<cors::Configured>>,
}

// It defines the main request handler for Hyper service request.
Expand All @@ -23,11 +25,27 @@ impl RequestHandler {
) -> impl Future<Output = Result<Response<Body>, Error>> + Send + 'a {
let method = req.method();
let headers = req.headers();

let root_dir = self.opts.root_dir.as_path();
let uri_path = req.uri().path();
let dir_listing = self.opts.dir_listing;

async move {
// CORS
if self.opts.cors.is_some() {
let cors = self.opts.cors.as_ref().unwrap();
match cors.check_request(method, headers) {
Ok(r) => {
tracing::debug!("cors ok: {:?}", r);
}
Err(e) => {
tracing::debug!("cors error kind: {:?}", e);
return error_page::get_error_response(method, &StatusCode::FORBIDDEN);
}
};
}

// Static files
match static_files::handle_request(method, headers, root_dir, uri_path, dir_listing)
.await
{
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ extern crate anyhow;
pub mod compression;
pub mod config;
pub mod control_headers;
pub mod cors;
pub mod error_page;
pub mod handler;
pub mod helpers;
Expand Down
8 changes: 5 additions & 3 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::handler::{RequestHandler, RequestHandlerOpts};
use crate::tls::{TlsAcceptor, TlsConfigBuilder};
use crate::Result;
use crate::{config::Config, service::RouterService};
use crate::{error_page, helpers, logger};
use crate::{cors, error_page, helpers, logger};

/// Define a multi-thread HTTP or HTTP/2 web server.
pub struct Server {
Expand Down Expand Up @@ -89,8 +89,6 @@ impl Server {
.set(helpers::read_file_content(opts.page50x.as_ref()))
.expect("page 50x is not initialized");

// TODO: CORS support

// Auto compression based on the `Accept-Encoding` header
let compression = opts.compression;
tracing::info!("auto compression compression: enabled={}", compression);
Expand All @@ -102,12 +100,16 @@ impl Server {
// Spawn a new Tokio asynchronous server task with its given options
let threads = self.threads;

// CORS support
let cors = cors::new(opts.cors_allow_origins.trim().to_string());

// Create a service router for Hyper
let router_service = RouterService::new(RequestHandler {
opts: RequestHandlerOpts {
root_dir,
compression,
dir_listing,
cors,
},
});

Expand Down

0 comments on commit af9a329

Please sign in to comment.