Skip to content

Commit af9a329

Browse files
committed
feat: cors support
1 parent 991d4b8 commit af9a329

File tree

4 files changed

+293
-5
lines changed

4 files changed

+293
-5
lines changed

src/cors.rs

Lines changed: 267 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,267 @@
1+
// CORS handler for incoming requests.
2+
// -> Part of the file is borrowed from https://github.com/seanmonstar/warp/blob/master/src/filters/cors.rs
3+
4+
use headers::{HeaderName, HeaderValue, Origin};
5+
use http::header;
6+
use std::{collections::HashSet, convert::TryFrom, sync::Arc};
7+
8+
/// It defines CORS instance.
9+
#[derive(Clone, Debug)]
10+
pub struct Cors {
11+
allowed_headers: HashSet<HeaderName>,
12+
max_age: Option<u64>,
13+
allowed_methods: HashSet<http::Method>,
14+
origins_str: String,
15+
origins: Option<HashSet<HeaderValue>>,
16+
}
17+
18+
/// It builds a new CORS instance.
19+
pub fn new(origins_str: String) -> Option<Arc<Configured>> {
20+
let cors = Cors::new(origins_str.clone());
21+
let cors = if origins_str.is_empty() {
22+
None
23+
} else if origins_str == "*" {
24+
Some(cors.allow_any_origin().allow_methods(vec!["GET", "HEAD"]))
25+
} else {
26+
let hosts = origins_str
27+
.split(',')
28+
.map(|s| s.trim().as_ref())
29+
.collect::<Vec<_>>();
30+
31+
if hosts.is_empty() {
32+
None
33+
} else {
34+
Some(cors.allow_origins(hosts).allow_methods(vec!["GET", "HEAD"]))
35+
}
36+
};
37+
if cors.is_some() {
38+
tracing::info!(
39+
"enabled=true, allow_methods=[GET, HEAD], allow_origins={}",
40+
origins_str
41+
);
42+
}
43+
Cors::build(cors)
44+
}
45+
46+
impl Cors {
47+
/// Creates a new Cors instance.
48+
pub fn new(origins_str: String) -> Self {
49+
Self {
50+
origins: None,
51+
allowed_headers: HashSet::new(),
52+
allowed_methods: HashSet::new(),
53+
max_age: None,
54+
origins_str,
55+
}
56+
}
57+
58+
/// Adds multiple methods to the existing list of allowed request methods.
59+
///
60+
/// # Panics
61+
///
62+
/// Panics if the provided argument is not a valid `http::Method`.
63+
pub fn allow_methods<I>(mut self, methods: I) -> Self
64+
where
65+
I: IntoIterator,
66+
http::Method: TryFrom<I::Item>,
67+
{
68+
let iter = methods.into_iter().map(|m| match TryFrom::try_from(m) {
69+
Ok(m) => m,
70+
Err(_) => panic!("cors: illegal method"),
71+
});
72+
self.allowed_methods.extend(iter);
73+
self
74+
}
75+
76+
/// Sets that *any* `Origin` header is allowed.
77+
///
78+
/// # Warning
79+
///
80+
/// This can allow websites you didn't intend to access this resource,
81+
/// it is usually better to set an explicit list.
82+
pub fn allow_any_origin(mut self) -> Self {
83+
self.origins = None;
84+
self
85+
}
86+
87+
/// Add multiple origins to the existing list of allowed `Origin`s.
88+
///
89+
/// # Panics
90+
///
91+
/// Panics if the provided argument is not a valid `Origin`.
92+
pub fn allow_origins<I>(mut self, origins: I) -> Self
93+
where
94+
I: IntoIterator,
95+
I::Item: IntoOrigin,
96+
{
97+
let iter = origins
98+
.into_iter()
99+
.map(IntoOrigin::into_origin)
100+
.map(|origin| {
101+
origin
102+
.to_string()
103+
.parse()
104+
.expect("cors: Origin is always a valid HeaderValue")
105+
});
106+
107+
self.origins.get_or_insert_with(HashSet::new).extend(iter);
108+
self
109+
}
110+
111+
/// Sets the `Access-Control-Max-Age` header.
112+
/// TODO: we could enable this in the future.
113+
///
114+
/// # Example
115+
///
116+
/// ```
117+
/// let cors = cors::new("*")
118+
/// .max_age(30) // 30u32 seconds
119+
/// .max_age(Duration::from_secs(30)); // or a Duration
120+
/// ```
121+
pub fn max_age(mut self, seconds: impl Seconds) -> Self {
122+
self.max_age = Some(seconds.seconds());
123+
self
124+
}
125+
126+
/// Builds the `Cors` wrapper from the configured settings.
127+
pub fn build(cors: Option<Cors>) -> Option<Arc<Configured>> {
128+
cors.as_ref()?;
129+
let cors = cors?;
130+
Some(Arc::new(Configured { cors }))
131+
}
132+
}
133+
134+
impl Default for Cors {
135+
fn default() -> Self {
136+
Self::new("*".to_string())
137+
}
138+
}
139+
140+
#[derive(Clone, Debug)]
141+
pub struct Configured {
142+
cors: Cors,
143+
}
144+
145+
#[derive(Debug)]
146+
pub enum Validated {
147+
Preflight(HeaderValue),
148+
Simple(HeaderValue),
149+
NotCors,
150+
}
151+
152+
#[derive(Debug)]
153+
pub enum Forbidden {
154+
Origin,
155+
Method,
156+
Header,
157+
}
158+
159+
impl Default for Forbidden {
160+
fn default() -> Self {
161+
Self::Origin
162+
}
163+
}
164+
165+
impl Configured {
166+
pub fn check_request(
167+
&self,
168+
method: &http::Method,
169+
headers: &http::HeaderMap,
170+
) -> Result<Validated, Forbidden> {
171+
match (headers.get(header::ORIGIN), method) {
172+
(Some(origin), &http::Method::OPTIONS) => {
173+
// OPTIONS requests are preflight CORS requests...
174+
175+
if !self.is_origin_allowed(origin) {
176+
return Err(Forbidden::Origin);
177+
}
178+
179+
if let Some(req_method) = headers.get(header::ACCESS_CONTROL_REQUEST_METHOD) {
180+
if !self.is_method_allowed(req_method) {
181+
return Err(Forbidden::Method);
182+
}
183+
} else {
184+
tracing::trace!(
185+
"cors: preflight request missing access-control-request-method header"
186+
);
187+
return Err(Forbidden::Method);
188+
}
189+
190+
if let Some(req_headers) = headers.get(header::ACCESS_CONTROL_REQUEST_HEADERS) {
191+
let headers = req_headers.to_str().map_err(|_| Forbidden::Header)?;
192+
for header in headers.split(',') {
193+
if !self.is_header_allowed(header.trim()) {
194+
return Err(Forbidden::Header);
195+
}
196+
}
197+
}
198+
199+
Ok(Validated::Preflight(origin.clone()))
200+
}
201+
(Some(origin), _) => {
202+
// Any other method, simply check for a valid origin...
203+
tracing::trace!("cors origin header: {:?}", origin);
204+
205+
if self.is_origin_allowed(origin) {
206+
Ok(Validated::Simple(origin.clone()))
207+
} else {
208+
Err(Forbidden::Origin)
209+
}
210+
}
211+
(None, _) => {
212+
// No `ORIGIN` header means this isn't CORS!
213+
Ok(Validated::NotCors)
214+
}
215+
}
216+
}
217+
218+
pub fn is_method_allowed(&self, header: &HeaderValue) -> bool {
219+
http::Method::from_bytes(header.as_bytes())
220+
.map(|method| self.cors.allowed_methods.contains(&method))
221+
.unwrap_or(false)
222+
}
223+
224+
pub fn is_header_allowed(&self, header: &str) -> bool {
225+
HeaderName::from_bytes(header.as_bytes())
226+
.map(|header| self.cors.allowed_headers.contains(&header))
227+
.unwrap_or(false)
228+
}
229+
230+
pub fn is_origin_allowed(&self, origin: &HeaderValue) -> bool {
231+
if let Some(ref allowed) = self.cors.origins {
232+
allowed.contains(origin)
233+
} else {
234+
true
235+
}
236+
}
237+
}
238+
239+
pub trait Seconds {
240+
fn seconds(self) -> u64;
241+
}
242+
243+
impl Seconds for u32 {
244+
fn seconds(self) -> u64 {
245+
self.into()
246+
}
247+
}
248+
249+
impl Seconds for ::std::time::Duration {
250+
fn seconds(self) -> u64 {
251+
self.as_secs()
252+
}
253+
}
254+
255+
pub trait IntoOrigin {
256+
fn into_origin(self) -> Origin;
257+
}
258+
259+
impl<'a> IntoOrigin for &'a str {
260+
fn into_origin(self) -> Origin {
261+
let mut parts = self.splitn(2, "://");
262+
let scheme = parts.next().expect("cors::into_origin: missing url scheme");
263+
let rest = parts.next().expect("cors::into_origin: missing url scheme");
264+
265+
Origin::try_from_parts(scheme, rest, None).expect("cors::into_origin: invalid Origin")
266+
}
267+
}

src/handler.rs

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
1+
use http::StatusCode;
12
use hyper::{Body, Request, Response};
2-
use std::{future::Future, path::PathBuf};
3+
use std::{future::Future, path::PathBuf, sync::Arc};
34

4-
use crate::{compression, control_headers, static_files};
5+
use crate::{compression, control_headers, cors, static_files};
56
use crate::{error_page, Error, Result};
67

78
// It defines options for a request handler.
89
pub struct RequestHandlerOpts {
910
pub root_dir: PathBuf,
1011
pub compression: bool,
1112
pub dir_listing: bool,
13+
pub cors: Option<Arc<cors::Configured>>,
1214
}
1315

1416
// It defines the main request handler for Hyper service request.
@@ -23,11 +25,27 @@ impl RequestHandler {
2325
) -> impl Future<Output = Result<Response<Body>, Error>> + Send + 'a {
2426
let method = req.method();
2527
let headers = req.headers();
28+
2629
let root_dir = self.opts.root_dir.as_path();
2730
let uri_path = req.uri().path();
2831
let dir_listing = self.opts.dir_listing;
2932

3033
async move {
34+
// CORS
35+
if self.opts.cors.is_some() {
36+
let cors = self.opts.cors.as_ref().unwrap();
37+
match cors.check_request(method, headers) {
38+
Ok(r) => {
39+
tracing::debug!("cors ok: {:?}", r);
40+
}
41+
Err(e) => {
42+
tracing::debug!("cors error kind: {:?}", e);
43+
return error_page::get_error_response(method, &StatusCode::FORBIDDEN);
44+
}
45+
};
46+
}
47+
48+
// Static files
3149
match static_files::handle_request(method, headers, root_dir, uri_path, dir_listing)
3250
.await
3351
{

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ extern crate anyhow;
66
pub mod compression;
77
pub mod config;
88
pub mod control_headers;
9+
pub mod cors;
910
pub mod error_page;
1011
pub mod handler;
1112
pub mod helpers;

src/server.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use crate::handler::{RequestHandler, RequestHandlerOpts};
88
use crate::tls::{TlsAcceptor, TlsConfigBuilder};
99
use crate::Result;
1010
use crate::{config::Config, service::RouterService};
11-
use crate::{error_page, helpers, logger};
11+
use crate::{cors, error_page, helpers, logger};
1212

1313
/// Define a multi-thread HTTP or HTTP/2 web server.
1414
pub struct Server {
@@ -89,8 +89,6 @@ impl Server {
8989
.set(helpers::read_file_content(opts.page50x.as_ref()))
9090
.expect("page 50x is not initialized");
9191

92-
// TODO: CORS support
93-
9492
// Auto compression based on the `Accept-Encoding` header
9593
let compression = opts.compression;
9694
tracing::info!("auto compression compression: enabled={}", compression);
@@ -102,12 +100,16 @@ impl Server {
102100
// Spawn a new Tokio asynchronous server task with its given options
103101
let threads = self.threads;
104102

103+
// CORS support
104+
let cors = cors::new(opts.cors_allow_origins.trim().to_string());
105+
105106
// Create a service router for Hyper
106107
let router_service = RouterService::new(RequestHandler {
107108
opts: RequestHandlerOpts {
108109
root_dir,
109110
compression,
110111
dir_listing,
112+
cors,
111113
},
112114
});
113115

0 commit comments

Comments
 (0)